mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-31 11:48:31 +00:00
fix: add ConfigDict for Pydantic model_config and ClassVar annotations
This commit is contained in:
@@ -1,17 +1,10 @@
|
|||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||||
@@ -19,6 +12,24 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
|||||||
from crewai.agents import CacheHandler
|
from crewai.agents import CacheHandler
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.agent_events import (
|
||||||
|
AgentExecutionCompletedEvent,
|
||||||
|
AgentExecutionErrorEvent,
|
||||||
|
AgentExecutionStartedEvent,
|
||||||
|
)
|
||||||
|
from crewai.events.types.knowledge_events import (
|
||||||
|
KnowledgeQueryCompletedEvent,
|
||||||
|
KnowledgeQueryFailedEvent,
|
||||||
|
KnowledgeQueryStartedEvent,
|
||||||
|
KnowledgeRetrievalCompletedEvent,
|
||||||
|
KnowledgeRetrievalStartedEvent,
|
||||||
|
KnowledgeSearchQueryFailedEvent,
|
||||||
|
)
|
||||||
|
from crewai.events.types.memory_events import (
|
||||||
|
MemoryRetrievalCompletedEvent,
|
||||||
|
MemoryRetrievalStartedEvent,
|
||||||
|
)
|
||||||
from crewai.knowledge.knowledge import Knowledge
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||||
@@ -38,24 +49,6 @@ from crewai.utilities.agent_utils import (
|
|||||||
)
|
)
|
||||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||||
from crewai.utilities.converter import generate_model_description
|
from crewai.utilities.converter import generate_model_description
|
||||||
from crewai.events.types.agent_events import (
|
|
||||||
AgentExecutionCompletedEvent,
|
|
||||||
AgentExecutionErrorEvent,
|
|
||||||
AgentExecutionStartedEvent,
|
|
||||||
)
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
|
||||||
from crewai.events.types.memory_events import (
|
|
||||||
MemoryRetrievalStartedEvent,
|
|
||||||
MemoryRetrievalCompletedEvent,
|
|
||||||
)
|
|
||||||
from crewai.events.types.knowledge_events import (
|
|
||||||
KnowledgeQueryCompletedEvent,
|
|
||||||
KnowledgeQueryFailedEvent,
|
|
||||||
KnowledgeQueryStartedEvent,
|
|
||||||
KnowledgeRetrievalCompletedEvent,
|
|
||||||
KnowledgeRetrievalStartedEvent,
|
|
||||||
KnowledgeSearchQueryFailedEvent,
|
|
||||||
)
|
|
||||||
from crewai.utilities.llm_utils import create_llm
|
from crewai.utilities.llm_utils import create_llm
|
||||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||||
@@ -87,36 +80,36 @@ class Agent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_times_executed: int = PrivateAttr(default=0)
|
_times_executed: int = PrivateAttr(default=0)
|
||||||
max_execution_time: Optional[int] = Field(
|
max_execution_time: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Maximum execution time for an agent to execute a task",
|
description="Maximum execution time for an agent to execute a task",
|
||||||
)
|
)
|
||||||
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||||
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||||
step_callback: Optional[Any] = Field(
|
step_callback: Any | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Callback to be executed after each step of the agent execution.",
|
description="Callback to be executed after each step of the agent execution.",
|
||||||
)
|
)
|
||||||
use_system_prompt: Optional[bool] = Field(
|
use_system_prompt: bool | None = Field(
|
||||||
default=True,
|
default=True,
|
||||||
description="Use system prompt for the agent.",
|
description="Use system prompt for the agent.",
|
||||||
)
|
)
|
||||||
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
llm: str | InstanceOf[BaseLLM] | Any = Field(
|
||||||
description="Language model that will run the agent.", default=None
|
description="Language model that will run the agent.", default=None
|
||||||
)
|
)
|
||||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||||
description="Language model that will run the agent.", default=None
|
description="Language model that will run the agent.", default=None
|
||||||
)
|
)
|
||||||
system_template: Optional[str] = Field(
|
system_template: str | None = Field(
|
||||||
default=None, description="System format for the agent."
|
default=None, description="System format for the agent."
|
||||||
)
|
)
|
||||||
prompt_template: Optional[str] = Field(
|
prompt_template: str | None = Field(
|
||||||
default=None, description="Prompt format for the agent."
|
default=None, description="Prompt format for the agent."
|
||||||
)
|
)
|
||||||
response_template: Optional[str] = Field(
|
response_template: str | None = Field(
|
||||||
default=None, description="Response format for the agent."
|
default=None, description="Response format for the agent."
|
||||||
)
|
)
|
||||||
allow_code_execution: Optional[bool] = Field(
|
allow_code_execution: bool | None = Field(
|
||||||
default=False, description="Enable code execution for the agent."
|
default=False, description="Enable code execution for the agent."
|
||||||
)
|
)
|
||||||
respect_context_window: bool = Field(
|
respect_context_window: bool = Field(
|
||||||
@@ -147,31 +140,31 @@ class Agent(BaseAgent):
|
|||||||
default=False,
|
default=False,
|
||||||
description="Whether the agent should reflect and create a plan before executing a task.",
|
description="Whether the agent should reflect and create a plan before executing a task.",
|
||||||
)
|
)
|
||||||
max_reasoning_attempts: Optional[int] = Field(
|
max_reasoning_attempts: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
|
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
|
||||||
)
|
)
|
||||||
embedder: Optional[Dict[str, Any]] = Field(
|
embedder: dict[str, Any] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Embedder configuration for the agent.",
|
description="Embedder configuration for the agent.",
|
||||||
)
|
)
|
||||||
agent_knowledge_context: Optional[str] = Field(
|
agent_knowledge_context: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Knowledge context for the agent.",
|
description="Knowledge context for the agent.",
|
||||||
)
|
)
|
||||||
crew_knowledge_context: Optional[str] = Field(
|
crew_knowledge_context: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Knowledge context for the crew.",
|
description="Knowledge context for the crew.",
|
||||||
)
|
)
|
||||||
knowledge_search_query: Optional[str] = Field(
|
knowledge_search_query: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Knowledge search query for the agent dynamically generated by the agent.",
|
description="Knowledge search query for the agent dynamically generated by the agent.",
|
||||||
)
|
)
|
||||||
from_repository: Optional[str] = Field(
|
from_repository: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The Agent's role to be used from your repository.",
|
description="The Agent's role to be used from your repository.",
|
||||||
)
|
)
|
||||||
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
|
guardrail: Callable[[Any], tuple[bool, Any]] | str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Function or string description of a guardrail to validate agent output",
|
description="Function or string description of a guardrail to validate agent output",
|
||||||
)
|
)
|
||||||
@@ -180,7 +173,7 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
def validate_from_repository(cls, v):
|
def validate_from_repository(self, v):
|
||||||
if v is not None and (from_repository := v.get("from_repository")):
|
if v is not None and (from_repository := v.get("from_repository")):
|
||||||
return load_agent_from_repository(from_repository) | v
|
return load_agent_from_repository(from_repository) | v
|
||||||
return v
|
return v
|
||||||
@@ -208,7 +201,7 @@ class Agent(BaseAgent):
|
|||||||
self.cache_handler = CacheHandler()
|
self.cache_handler = CacheHandler()
|
||||||
self.set_cache_handler(self.cache_handler)
|
self.set_cache_handler(self.cache_handler)
|
||||||
|
|
||||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
|
||||||
try:
|
try:
|
||||||
if self.embedder is None and crew_embedder:
|
if self.embedder is None and crew_embedder:
|
||||||
self.embedder = crew_embedder
|
self.embedder = crew_embedder
|
||||||
@@ -224,7 +217,7 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
self.knowledge.add_sources()
|
self.knowledge.add_sources()
|
||||||
except (TypeError, ValueError) as e:
|
except (TypeError, ValueError) as e:
|
||||||
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
|
raise ValueError(f"Invalid Knowledge Configuration: {e!s}") from e
|
||||||
|
|
||||||
def _is_any_available_memory(self) -> bool:
|
def _is_any_available_memory(self) -> bool:
|
||||||
"""Check if any memory is available."""
|
"""Check if any memory is available."""
|
||||||
@@ -244,8 +237,8 @@ class Agent(BaseAgent):
|
|||||||
def execute_task(
|
def execute_task(
|
||||||
self,
|
self,
|
||||||
task: Task,
|
task: Task,
|
||||||
context: Optional[str] = None,
|
context: str | None = None,
|
||||||
tools: Optional[List[BaseTool]] = None,
|
tools: list[BaseTool] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Execute a task with the agent.
|
"""Execute a task with the agent.
|
||||||
|
|
||||||
@@ -278,11 +271,9 @@ class Agent(BaseAgent):
|
|||||||
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if hasattr(self, "_logger"):
|
if hasattr(self, "_logger"):
|
||||||
self._logger.log(
|
self._logger.log("error", f"Error during reasoning process: {e!s}")
|
||||||
"error", f"Error during reasoning process: {str(e)}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print(f"Error during reasoning process: {str(e)}")
|
print(f"Error during reasoning process: {e!s}")
|
||||||
|
|
||||||
self._inject_date_to_task(task)
|
self._inject_date_to_task(task)
|
||||||
|
|
||||||
@@ -525,14 +516,14 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return future.result(timeout=timeout)
|
return future.result(timeout=timeout)
|
||||||
except concurrent.futures.TimeoutError:
|
except concurrent.futures.TimeoutError as e:
|
||||||
future.cancel()
|
future.cancel()
|
||||||
raise TimeoutError(
|
raise TimeoutError(
|
||||||
f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
|
f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
|
||||||
)
|
) from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
future.cancel()
|
future.cancel()
|
||||||
raise RuntimeError(f"Task execution failed: {str(e)}")
|
raise RuntimeError(f"Task execution failed: {e!s}") from e
|
||||||
|
|
||||||
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
|
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
|
||||||
"""Execute a task without a timeout.
|
"""Execute a task without a timeout.
|
||||||
@@ -554,14 +545,14 @@ class Agent(BaseAgent):
|
|||||||
)["output"]
|
)["output"]
|
||||||
|
|
||||||
def create_agent_executor(
|
def create_agent_executor(
|
||||||
self, tools: Optional[List[BaseTool]] = None, task=None
|
self, tools: list[BaseTool] | None = None, task=None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create an agent executor for the agent.
|
"""Create an agent executor for the agent.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instance of the CrewAgentExecutor class.
|
An instance of the CrewAgentExecutor class.
|
||||||
"""
|
"""
|
||||||
raw_tools: List[BaseTool] = tools or self.tools or []
|
raw_tools: list[BaseTool] = tools or self.tools or []
|
||||||
parsed_tools = parse_tools(raw_tools)
|
parsed_tools = parse_tools(raw_tools)
|
||||||
|
|
||||||
prompt = Prompts(
|
prompt = Prompts(
|
||||||
@@ -603,10 +594,9 @@ class Agent(BaseAgent):
|
|||||||
callbacks=[TokenCalcHandler(self._token_process)],
|
callbacks=[TokenCalcHandler(self._token_process)],
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_delegation_tools(self, agents: List[BaseAgent]):
|
def get_delegation_tools(self, agents: list[BaseAgent]):
|
||||||
agent_tools = AgentTools(agents=agents)
|
agent_tools = AgentTools(agents=agents)
|
||||||
tools = agent_tools.tools()
|
return agent_tools.tools()
|
||||||
return tools
|
|
||||||
|
|
||||||
def get_multimodal_tools(self) -> Sequence[BaseTool]:
|
def get_multimodal_tools(self) -> Sequence[BaseTool]:
|
||||||
from crewai.tools.agent_tools.add_image_tool import AddImageTool
|
from crewai.tools.agent_tools.add_image_tool import AddImageTool
|
||||||
@@ -654,7 +644,7 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
return task_prompt
|
return task_prompt
|
||||||
|
|
||||||
def _render_text_description(self, tools: List[Any]) -> str:
|
def _render_text_description(self, tools: list[Any]) -> str:
|
||||||
"""Render the tool name and description in plain text.
|
"""Render the tool name and description in plain text.
|
||||||
|
|
||||||
Output will be in the format of:
|
Output will be in the format of:
|
||||||
@@ -664,15 +654,13 @@ class Agent(BaseAgent):
|
|||||||
search: This tool is used for search
|
search: This tool is used for search
|
||||||
calculator: This tool is used for math
|
calculator: This tool is used for math
|
||||||
"""
|
"""
|
||||||
description = "\n".join(
|
return "\n".join(
|
||||||
[
|
[
|
||||||
f"Tool name: {tool.name}\nTool description:\n{tool.description}"
|
f"Tool name: {tool.name}\nTool description:\n{tool.description}"
|
||||||
for tool in tools
|
for tool in tools
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
return description
|
|
||||||
|
|
||||||
def _inject_date_to_task(self, task):
|
def _inject_date_to_task(self, task):
|
||||||
"""Inject the current date into the task description if inject_date is enabled."""
|
"""Inject the current date into the task description if inject_date is enabled."""
|
||||||
if self.inject_date:
|
if self.inject_date:
|
||||||
@@ -700,9 +688,9 @@ class Agent(BaseAgent):
|
|||||||
task.description += f"\n\nCurrent Date: {current_date}"
|
task.description += f"\n\nCurrent Date: {current_date}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if hasattr(self, "_logger"):
|
if hasattr(self, "_logger"):
|
||||||
self._logger.log("warning", f"Failed to inject date: {str(e)}")
|
self._logger.log("warning", f"Failed to inject date: {e!s}")
|
||||||
else:
|
else:
|
||||||
print(f"Warning: Failed to inject date: {str(e)}")
|
print(f"Warning: Failed to inject date: {e!s}")
|
||||||
|
|
||||||
def _validate_docker_installation(self) -> None:
|
def _validate_docker_installation(self) -> None:
|
||||||
"""Check if Docker is installed and running."""
|
"""Check if Docker is installed and running."""
|
||||||
@@ -718,10 +706,10 @@ class Agent(BaseAgent):
|
|||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
stderr=subprocess.PIPE,
|
||||||
)
|
)
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
|
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
|
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
|
||||||
@@ -796,8 +784,8 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
def kickoff(
|
def kickoff(
|
||||||
self,
|
self,
|
||||||
messages: Union[str, List[Dict[str, str]]],
|
messages: str | list[dict[str, str]],
|
||||||
response_format: Optional[Type[Any]] = None,
|
response_format: type[Any] | None = None,
|
||||||
) -> LiteAgentOutput:
|
) -> LiteAgentOutput:
|
||||||
"""
|
"""
|
||||||
Execute the agent with the given messages using a LiteAgent instance.
|
Execute the agent with the given messages using a LiteAgent instance.
|
||||||
@@ -836,8 +824,8 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
async def kickoff_async(
|
async def kickoff_async(
|
||||||
self,
|
self,
|
||||||
messages: Union[str, List[Dict[str, str]]],
|
messages: str | list[dict[str, str]],
|
||||||
response_format: Optional[Type[Any]] = None,
|
response_format: type[Any] | None = None,
|
||||||
) -> LiteAgentOutput:
|
) -> LiteAgentOutput:
|
||||||
"""
|
"""
|
||||||
Execute the agent asynchronously with the given messages using a LiteAgent instance.
|
Execute the agent asynchronously with the given messages using a LiteAgent instance.
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
from crewai.agents.cache.cache_handler import CacheHandler
|
from crewai.agents.cache.cache_handler import CacheHandler
|
||||||
from crewai.agents.parser import parse, AgentAction, AgentFinish, OutputParserException
|
from crewai.agents.parser import AgentAction, AgentFinish, OutputParserException, parse
|
||||||
from crewai.agents.tools_handler import ToolsHandler
|
from crewai.agents.tools_handler import ToolsHandler
|
||||||
|
|
||||||
__all__ = ["CacheHandler", "parse", "AgentAction", "AgentFinish", "OutputParserException", "ToolsHandler"]
|
__all__ = [
|
||||||
|
"AgentAction",
|
||||||
|
"AgentFinish",
|
||||||
|
"CacheHandler",
|
||||||
|
"OutputParserException",
|
||||||
|
"ToolsHandler",
|
||||||
|
"parse",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import PrivateAttr
|
from pydantic import ConfigDict, PrivateAttr
|
||||||
|
|
||||||
from crewai.agent import BaseAgent
|
from crewai.agent import BaseAgent
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
@@ -16,22 +16,21 @@ class BaseAgentAdapter(BaseAgent, ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
adapted_structured_output: bool = False
|
adapted_structured_output: bool = False
|
||||||
_agent_config: Optional[Dict[str, Any]] = PrivateAttr(default=None)
|
_agent_config: dict[str, Any] | None = PrivateAttr(default=None)
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
def __init__(self, agent_config: Optional[Dict[str, Any]] = None, **kwargs: Any):
|
def __init__(self, agent_config: dict[str, Any] | None = None, **kwargs: Any):
|
||||||
super().__init__(adapted_agent=True, **kwargs)
|
super().__init__(adapted_agent=True, **kwargs)
|
||||||
self._agent_config = agent_config
|
self._agent_config = agent_config
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
|
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
|
||||||
"""Configure and adapt tools for the specific agent implementation.
|
"""Configure and adapt tools for the specific agent implementation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tools: Optional list of BaseTool instances to be configured
|
tools: Optional list of BaseTool instances to be configured
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def configure_structured_output(self, structured_output: Any) -> None:
|
def configure_structured_output(self, structured_output: Any) -> None:
|
||||||
"""Configure the structured output for the specific agent implementation.
|
"""Configure the structured output for the specific agent implementation.
|
||||||
@@ -39,4 +38,3 @@ class BaseAgentAdapter(BaseAgent, ABC):
|
|||||||
Args:
|
Args:
|
||||||
structured_output: The structured output to be configured
|
structured_output: The structured output to be configured
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
|
|
||||||
@@ -12,23 +12,22 @@ class BaseToolAdapter(ABC):
|
|||||||
different frameworks and platforms.
|
different frameworks and platforms.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
original_tools: List[BaseTool]
|
original_tools: list[BaseTool]
|
||||||
converted_tools: List[Any]
|
converted_tools: list[Any]
|
||||||
|
|
||||||
def __init__(self, tools: Optional[List[BaseTool]] = None):
|
def __init__(self, tools: list[BaseTool] | None = None):
|
||||||
self.original_tools = tools or []
|
self.original_tools = tools or []
|
||||||
self.converted_tools = []
|
self.converted_tools = []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def configure_tools(self, tools: List[BaseTool]) -> None:
|
def configure_tools(self, tools: list[BaseTool]) -> None:
|
||||||
"""Configure and convert tools for the specific implementation.
|
"""Configure and convert tools for the specific implementation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tools: List of BaseTool instances to be configured and converted
|
tools: List of BaseTool instances to be configured and converted
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def tools(self) -> List[Any]:
|
def tools(self) -> list[Any]:
|
||||||
"""Return all converted tools."""
|
"""Return all converted tools."""
|
||||||
return self.converted_tools
|
return self.converted_tools
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Callable
|
||||||
from copy import copy as shallow_copy
|
from copy import copy as shallow_copy
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
UUID4,
|
UUID4,
|
||||||
@@ -25,7 +26,6 @@ from crewai.security.security_config import SecurityConfig
|
|||||||
from crewai.tools.base_tool import BaseTool, Tool
|
from crewai.tools.base_tool import BaseTool, Tool
|
||||||
from crewai.utilities import I18N, Logger, RPMController
|
from crewai.utilities import I18N, Logger, RPMController
|
||||||
from crewai.utilities.config import process_config
|
from crewai.utilities.config import process_config
|
||||||
from crewai.utilities.converter import Converter
|
|
||||||
from crewai.utilities.string_utils import interpolate_only
|
from crewai.utilities.string_utils import interpolate_only
|
||||||
|
|
||||||
T = TypeVar("T", bound="BaseAgent")
|
T = TypeVar("T", bound="BaseAgent")
|
||||||
@@ -81,17 +81,17 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
|
|
||||||
__hash__ = object.__hash__ # type: ignore
|
__hash__ = object.__hash__ # type: ignore
|
||||||
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
|
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
|
||||||
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
|
_rpm_controller: RPMController | None = PrivateAttr(default=None)
|
||||||
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
||||||
_original_role: Optional[str] = PrivateAttr(default=None)
|
_original_role: str | None = PrivateAttr(default=None)
|
||||||
_original_goal: Optional[str] = PrivateAttr(default=None)
|
_original_goal: str | None = PrivateAttr(default=None)
|
||||||
_original_backstory: Optional[str] = PrivateAttr(default=None)
|
_original_backstory: str | None = PrivateAttr(default=None)
|
||||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||||
role: str = Field(description="Role of the agent")
|
role: str = Field(description="Role of the agent")
|
||||||
goal: str = Field(description="Objective of the agent")
|
goal: str = Field(description="Objective of the agent")
|
||||||
backstory: str = Field(description="Backstory of the agent")
|
backstory: str = Field(description="Backstory of the agent")
|
||||||
config: Optional[Dict[str, Any]] = Field(
|
config: dict[str, Any] | None = Field(
|
||||||
description="Configuration for the agent", default=None, exclude=True
|
description="Configuration for the agent", default=None, exclude=True
|
||||||
)
|
)
|
||||||
cache: bool = Field(
|
cache: bool = Field(
|
||||||
@@ -100,7 +100,7 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
verbose: bool = Field(
|
verbose: bool = Field(
|
||||||
default=False, description="Verbose mode for the Agent Execution"
|
default=False, description="Verbose mode for the Agent Execution"
|
||||||
)
|
)
|
||||||
max_rpm: Optional[int] = Field(
|
max_rpm: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Maximum number of requests per minute for the agent execution to be respected.",
|
description="Maximum number of requests per minute for the agent execution to be respected.",
|
||||||
)
|
)
|
||||||
@@ -108,7 +108,7 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
default=False,
|
default=False,
|
||||||
description="Enable agent to delegate and ask questions among each other.",
|
description="Enable agent to delegate and ask questions among each other.",
|
||||||
)
|
)
|
||||||
tools: Optional[List[BaseTool]] = Field(
|
tools: list[BaseTool] | None = Field(
|
||||||
default_factory=list, description="Tools at agents' disposal"
|
default_factory=list, description="Tools at agents' disposal"
|
||||||
)
|
)
|
||||||
max_iter: int = Field(
|
max_iter: int = Field(
|
||||||
@@ -122,27 +122,27 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
)
|
)
|
||||||
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
|
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
|
||||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||||
cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
|
cache_handler: InstanceOf[CacheHandler] | None = Field(
|
||||||
default=None, description="An instance of the CacheHandler class."
|
default=None, description="An instance of the CacheHandler class."
|
||||||
)
|
)
|
||||||
tools_handler: InstanceOf[ToolsHandler] = Field(
|
tools_handler: InstanceOf[ToolsHandler] = Field(
|
||||||
default_factory=ToolsHandler,
|
default_factory=ToolsHandler,
|
||||||
description="An instance of the ToolsHandler class.",
|
description="An instance of the ToolsHandler class.",
|
||||||
)
|
)
|
||||||
tools_results: List[Dict[str, Any]] = Field(
|
tools_results: list[dict[str, Any]] = Field(
|
||||||
default=[], description="Results of the tools used by the agent."
|
default=[], description="Results of the tools used by the agent."
|
||||||
)
|
)
|
||||||
max_tokens: Optional[int] = Field(
|
max_tokens: int | None = Field(
|
||||||
default=None, description="Maximum number of tokens for the agent's execution."
|
default=None, description="Maximum number of tokens for the agent's execution."
|
||||||
)
|
)
|
||||||
knowledge: Optional[Knowledge] = Field(
|
knowledge: Knowledge | None = Field(
|
||||||
default=None, description="Knowledge for the agent."
|
default=None, description="Knowledge for the agent."
|
||||||
)
|
)
|
||||||
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Knowledge sources for the agent.",
|
description="Knowledge sources for the agent.",
|
||||||
)
|
)
|
||||||
knowledge_storage: Optional[Any] = Field(
|
knowledge_storage: Any | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Custom knowledge storage for the agent.",
|
description="Custom knowledge storage for the agent.",
|
||||||
)
|
)
|
||||||
@@ -150,13 +150,13 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
default_factory=SecurityConfig,
|
default_factory=SecurityConfig,
|
||||||
description="Security configuration for the agent, including fingerprinting.",
|
description="Security configuration for the agent, including fingerprinting.",
|
||||||
)
|
)
|
||||||
callbacks: List[Callable] = Field(
|
callbacks: list[Callable] = Field(
|
||||||
default=[], description="Callbacks to be used for the agent"
|
default=[], description="Callbacks to be used for the agent"
|
||||||
)
|
)
|
||||||
adapted_agent: bool = Field(
|
adapted_agent: bool = Field(
|
||||||
default=False, description="Whether the agent is adapted"
|
default=False, description="Whether the agent is adapted"
|
||||||
)
|
)
|
||||||
knowledge_config: Optional[KnowledgeConfig] = Field(
|
knowledge_config: KnowledgeConfig | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Knowledge configuration for the agent such as limits and threshold",
|
description="Knowledge configuration for the agent such as limits and threshold",
|
||||||
)
|
)
|
||||||
@@ -168,7 +168,7 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
|
|
||||||
@field_validator("tools")
|
@field_validator("tools")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
|
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
|
||||||
"""Validate and process the tools provided to the agent.
|
"""Validate and process the tools provided to the agent.
|
||||||
|
|
||||||
This method ensures that each tool is either an instance of BaseTool
|
This method ensures that each tool is either an instance of BaseTool
|
||||||
@@ -221,7 +221,7 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
|
|
||||||
@field_validator("id", mode="before")
|
@field_validator("id", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
|
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
|
||||||
if v:
|
if v:
|
||||||
raise PydanticCustomError(
|
raise PydanticCustomError(
|
||||||
"may_not_set_field", "This field is not to be set by the user.", {}
|
"may_not_set_field", "This field is not to be set by the user.", {}
|
||||||
@@ -252,8 +252,8 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
def execute_task(
|
def execute_task(
|
||||||
self,
|
self,
|
||||||
task: Any,
|
task: Any,
|
||||||
context: Optional[str] = None,
|
context: str | None = None,
|
||||||
tools: Optional[List[BaseTool]] = None,
|
tools: list[BaseTool] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -262,9 +262,8 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
|
def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
|
||||||
"""Set the task tools that init BaseAgenTools class."""
|
"""Set the task tools that init BaseAgenTools class."""
|
||||||
pass
|
|
||||||
|
|
||||||
def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
|
def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
|
||||||
"""Create a deep copy of the Agent."""
|
"""Create a deep copy of the Agent."""
|
||||||
@@ -309,7 +308,7 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
|
|
||||||
copied_data = self.model_dump(exclude=exclude)
|
copied_data = self.model_dump(exclude=exclude)
|
||||||
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
||||||
copied_agent = type(self)(
|
return type(self)(
|
||||||
**copied_data,
|
**copied_data,
|
||||||
llm=existing_llm,
|
llm=existing_llm,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
@@ -318,9 +317,7 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
knowledge_storage=copied_knowledge_storage,
|
knowledge_storage=copied_knowledge_storage,
|
||||||
)
|
)
|
||||||
|
|
||||||
return copied_agent
|
def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
|
||||||
|
|
||||||
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
|
||||||
"""Interpolate inputs into the agent description and backstory."""
|
"""Interpolate inputs into the agent description and backstory."""
|
||||||
if self._original_role is None:
|
if self._original_role is None:
|
||||||
self._original_role = self.role
|
self._original_role = self.role
|
||||||
@@ -362,5 +359,5 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
self._rpm_controller = rpm_controller
|
self._rpm_controller = rpm_controller
|
||||||
self.create_agent_executor()
|
self.create_agent_executor()
|
||||||
|
|
||||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Dict, List
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from crewai.events.event_listener import event_listener
|
||||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||||
from crewai.utilities import I18N
|
from crewai.utilities import I18N
|
||||||
from crewai.utilities.converter import ConverterError
|
from crewai.utilities.converter import ConverterError
|
||||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer
|
||||||
from crewai.events.event_listener import event_listener
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
@@ -21,7 +21,7 @@ class CrewAgentExecutorMixin:
|
|||||||
task: "Task"
|
task: "Task"
|
||||||
iterations: int
|
iterations: int
|
||||||
max_iter: int
|
max_iter: int
|
||||||
messages: List[Dict[str, str]]
|
messages: list[dict[str, str]]
|
||||||
_i18n: I18N
|
_i18n: I18N
|
||||||
_printer: Printer = Printer()
|
_printer: Printer = Printer()
|
||||||
|
|
||||||
@@ -46,7 +46,6 @@ class CrewAgentExecutorMixin:
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to add to short term memory: {e}")
|
print(f"Failed to add to short term memory: {e}")
|
||||||
pass
|
|
||||||
|
|
||||||
def _create_external_memory(self, output) -> None:
|
def _create_external_memory(self, output) -> None:
|
||||||
"""Create and save a external-term memory item if conditions are met."""
|
"""Create and save a external-term memory item if conditions are met."""
|
||||||
@@ -67,7 +66,6 @@ class CrewAgentExecutorMixin:
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to add to external memory: {e}")
|
print(f"Failed to add to external memory: {e}")
|
||||||
pass
|
|
||||||
|
|
||||||
def _create_long_term_memory(self, output) -> None:
|
def _create_long_term_memory(self, output) -> None:
|
||||||
"""Create and save long-term and entity memory items based on evaluation."""
|
"""Create and save long-term and entity memory items based on evaluation."""
|
||||||
@@ -113,10 +111,8 @@ class CrewAgentExecutorMixin:
|
|||||||
self.crew._entity_memory.save(entity_memories)
|
self.crew._entity_memory.save(entity_memories)
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
print(f"Missing attributes for long term memory: {e}")
|
print(f"Missing attributes for long term memory: {e}")
|
||||||
pass
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to add to long term memory: {e}")
|
print(f"Failed to add to long term memory: {e}")
|
||||||
pass
|
|
||||||
elif (
|
elif (
|
||||||
self.crew
|
self.crew
|
||||||
and self.crew._long_term_memory
|
and self.crew._long_term_memory
|
||||||
|
|||||||
@@ -251,9 +251,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
i18n=self._i18n,
|
i18n=self._i18n,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
else:
|
handle_unknown_error(self._printer, e)
|
||||||
handle_unknown_error(self._printer, e)
|
raise e
|
||||||
raise e
|
|
||||||
finally:
|
finally:
|
||||||
self.iterations += 1
|
self.iterations += 1
|
||||||
|
|
||||||
@@ -324,9 +323,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
self.agent,
|
self.agent,
|
||||||
AgentLogsStartedEvent(
|
AgentLogsStartedEvent(
|
||||||
agent_role=self.agent.role,
|
agent_role=self.agent.role,
|
||||||
task_description=(
|
task_description=(self.task.description if self.task else "Not Found"),
|
||||||
getattr(self.task, "description") if self.task else "Not Found"
|
|
||||||
),
|
|
||||||
verbose=self.agent.verbose
|
verbose=self.agent.verbose
|
||||||
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
|
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
|
||||||
),
|
),
|
||||||
@@ -415,8 +412,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
"""
|
"""
|
||||||
prompt = prompt.replace("{input}", inputs["input"])
|
prompt = prompt.replace("{input}", inputs["input"])
|
||||||
prompt = prompt.replace("{tool_names}", inputs["tool_names"])
|
prompt = prompt.replace("{tool_names}", inputs["tool_names"])
|
||||||
prompt = prompt.replace("{tools}", inputs["tools"])
|
return prompt.replace("{tools}", inputs["tools"])
|
||||||
return prompt
|
|
||||||
|
|
||||||
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
|
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
|
||||||
"""Process human feedback.
|
"""Process human feedback.
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ from dataclasses import dataclass
|
|||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
from crewai.agents.constants import (
|
from crewai.agents.constants import (
|
||||||
|
ACTION_INPUT_ONLY_REGEX,
|
||||||
ACTION_INPUT_REGEX,
|
ACTION_INPUT_REGEX,
|
||||||
ACTION_REGEX,
|
ACTION_REGEX,
|
||||||
ACTION_INPUT_ONLY_REGEX,
|
|
||||||
FINAL_ANSWER_ACTION,
|
FINAL_ANSWER_ACTION,
|
||||||
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
|
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
|
||||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||||
@@ -104,7 +104,7 @@ def parse(text: str) -> AgentAction | AgentFinish:
|
|||||||
final_answer = final_answer[:-3].rstrip()
|
final_answer = final_answer[:-3].rstrip()
|
||||||
return AgentFinish(thought=thought, output=final_answer, text=text)
|
return AgentFinish(thought=thought, output=final_answer, text=text)
|
||||||
|
|
||||||
elif action_match:
|
if action_match:
|
||||||
action = action_match.group(1)
|
action = action_match.group(1)
|
||||||
clean_action = _clean_action(action)
|
clean_action = _clean_action(action)
|
||||||
|
|
||||||
@@ -121,16 +121,15 @@ def parse(text: str) -> AgentAction | AgentFinish:
|
|||||||
raise OutputParserException(
|
raise OutputParserException(
|
||||||
f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{_I18N.slice('final_answer_format')}",
|
f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{_I18N.slice('final_answer_format')}",
|
||||||
)
|
)
|
||||||
elif not ACTION_INPUT_ONLY_REGEX.search(text):
|
if not ACTION_INPUT_ONLY_REGEX.search(text):
|
||||||
raise OutputParserException(
|
raise OutputParserException(
|
||||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||||
)
|
)
|
||||||
else:
|
err_format = _I18N.slice("format_without_tools")
|
||||||
err_format = _I18N.slice("format_without_tools")
|
error = f"{err_format}"
|
||||||
error = f"{err_format}"
|
raise OutputParserException(
|
||||||
raise OutputParserException(
|
error,
|
||||||
error,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_thought(text: str) -> str:
|
def _extract_thought(text: str) -> str:
|
||||||
@@ -149,8 +148,7 @@ def _extract_thought(text: str) -> str:
|
|||||||
return ""
|
return ""
|
||||||
thought = text[:thought_index].strip()
|
thought = text[:thought_index].strip()
|
||||||
# Remove any triple backticks from the thought string
|
# Remove any triple backticks from the thought string
|
||||||
thought = thought.replace("```", "").strip()
|
return thought.replace("```", "").strip()
|
||||||
return thought
|
|
||||||
|
|
||||||
|
|
||||||
def _clean_action(text: str) -> str:
|
def _clean_action(text: str) -> str:
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Tools handler for managing tool execution and caching."""
|
"""Tools handler for managing tool execution and caching."""
|
||||||
|
|
||||||
|
from crewai.agents.cache.cache_handler import CacheHandler
|
||||||
from crewai.tools.cache_tools.cache_tools import CacheTools
|
from crewai.tools.cache_tools.cache_tools import CacheTools
|
||||||
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
|
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
|
||||||
from crewai.agents.cache.cache_handler import CacheHandler
|
|
||||||
|
|
||||||
|
|
||||||
class ToolsHandler:
|
class ToolsHandler:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||||
|
|
||||||
|
|
||||||
class Auth0Provider(BaseProvider):
|
class Auth0Provider(BaseProvider):
|
||||||
def get_authorize_url(self) -> str:
|
def get_authorize_url(self) -> str:
|
||||||
return f"https://{self._get_domain()}/oauth/device/code"
|
return f"https://{self._get_domain()}/oauth/device/code"
|
||||||
|
|||||||
@@ -1,30 +1,26 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from crewai.cli.authentication.main import Oauth2Settings
|
from crewai.cli.authentication.main import Oauth2Settings
|
||||||
|
|
||||||
|
|
||||||
class BaseProvider(ABC):
|
class BaseProvider(ABC):
|
||||||
def __init__(self, settings: Oauth2Settings):
|
def __init__(self, settings: Oauth2Settings):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_authorize_url(self) -> str:
|
def get_authorize_url(self) -> str: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_token_url(self) -> str:
|
def get_token_url(self) -> str: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_jwks_url(self) -> str:
|
def get_jwks_url(self) -> str: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_issuer(self) -> str:
|
def get_issuer(self) -> str: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_audience(self) -> str:
|
def get_audience(self) -> str: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_client_id(self) -> str:
|
def get_client_id(self) -> str: ...
|
||||||
...
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||||
|
|
||||||
|
|
||||||
class OktaProvider(BaseProvider):
|
class OktaProvider(BaseProvider):
|
||||||
def get_authorize_url(self) -> str:
|
def get_authorize_url(self) -> str:
|
||||||
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
|
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||||
|
|
||||||
|
|
||||||
class WorkosProvider(BaseProvider):
|
class WorkosProvider(BaseProvider):
|
||||||
def get_authorize_url(self) -> str:
|
def get_authorize_url(self) -> str:
|
||||||
return f"https://{self._get_domain()}/oauth2/device_authorization"
|
return f"https://{self._get_domain()}/oauth2/device_authorization"
|
||||||
@@ -17,9 +18,11 @@ class WorkosProvider(BaseProvider):
|
|||||||
return self.settings.audience or ""
|
return self.settings.audience or ""
|
||||||
|
|
||||||
def get_client_id(self) -> str:
|
def get_client_id(self) -> str:
|
||||||
assert self.settings.client_id is not None, "Client ID is required"
|
if self.settings.client_id is None:
|
||||||
|
raise RuntimeError("Client ID is required")
|
||||||
return self.settings.client_id
|
return self.settings.client_id
|
||||||
|
|
||||||
def _get_domain(self) -> str:
|
def _get_domain(self) -> str:
|
||||||
assert self.settings.domain is not None, "Domain is required"
|
if self.settings.domain is None:
|
||||||
|
raise RuntimeError("Domain is required")
|
||||||
return self.settings.domain
|
return self.settings.domain
|
||||||
|
|||||||
@@ -17,8 +17,6 @@ def validate_jwt_token(
|
|||||||
missing required claims).
|
missing required claims).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
decoded_token = None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
jwk_client = PyJWKClient(jwks_url)
|
jwk_client = PyJWKClient(jwks_url)
|
||||||
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
|
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
|
||||||
@@ -26,7 +24,7 @@ def validate_jwt_token(
|
|||||||
_unverified_decoded_token = jwt.decode(
|
_unverified_decoded_token = jwt.decode(
|
||||||
jwt_token, options={"verify_signature": False}
|
jwt_token, options={"verify_signature": False}
|
||||||
)
|
)
|
||||||
decoded_token = jwt.decode(
|
return jwt.decode(
|
||||||
jwt_token,
|
jwt_token,
|
||||||
signing_key.key,
|
signing_key.key,
|
||||||
algorithms=["RS256"],
|
algorithms=["RS256"],
|
||||||
@@ -40,7 +38,6 @@ def validate_jwt_token(
|
|||||||
"require": ["exp", "iat", "iss", "aud", "sub"],
|
"require": ["exp", "iat", "iss", "aud", "sub"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return decoded_token
|
|
||||||
|
|
||||||
except jwt.ExpiredSignatureError:
|
except jwt.ExpiredSignatureError:
|
||||||
raise Exception("Token has expired.")
|
raise Exception("Token has expired.")
|
||||||
@@ -55,8 +52,8 @@ def validate_jwt_token(
|
|||||||
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
|
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
|
||||||
)
|
)
|
||||||
except jwt.MissingRequiredClaimError as e:
|
except jwt.MissingRequiredClaimError as e:
|
||||||
raise Exception(f"Token is missing required claims: {str(e)}")
|
raise Exception(f"Token is missing required claims: {e!s}")
|
||||||
except jwt.exceptions.PyJWKClientError as e:
|
except jwt.exceptions.PyJWKClientError as e:
|
||||||
raise Exception(f"JWKS or key processing error: {str(e)}")
|
raise Exception(f"JWKS or key processing error: {e!s}")
|
||||||
except jwt.InvalidTokenError as e:
|
except jwt.InvalidTokenError as e:
|
||||||
raise Exception(f"Invalid token: {str(e)}")
|
raise Exception(f"Invalid token: {e!s}")
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
from importlib.metadata import version as get_version
|
from importlib.metadata import version as get_version
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from crewai.cli.config import Settings
|
|
||||||
from crewai.cli.settings.main import SettingsCommand
|
|
||||||
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
||||||
|
from crewai.cli.config import Settings
|
||||||
from crewai.cli.create_crew import create_crew
|
from crewai.cli.create_crew import create_crew
|
||||||
from crewai.cli.create_flow import create_flow
|
from crewai.cli.create_flow import create_flow
|
||||||
from crewai.cli.crew_chat import run_chat
|
from crewai.cli.crew_chat import run_chat
|
||||||
|
from crewai.cli.settings.main import SettingsCommand
|
||||||
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||||
KickoffTaskOutputsSQLiteStorage,
|
KickoffTaskOutputsSQLiteStorage,
|
||||||
)
|
)
|
||||||
@@ -237,13 +237,11 @@ def login():
|
|||||||
@crewai.group()
|
@crewai.group()
|
||||||
def deploy():
|
def deploy():
|
||||||
"""Deploy the Crew CLI group."""
|
"""Deploy the Crew CLI group."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@crewai.group()
|
@crewai.group()
|
||||||
def tool():
|
def tool():
|
||||||
"""Tool Repository related commands."""
|
"""Tool Repository related commands."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@deploy.command(name="create")
|
@deploy.command(name="create")
|
||||||
@@ -263,7 +261,7 @@ def deploy_list():
|
|||||||
|
|
||||||
@deploy.command(name="push")
|
@deploy.command(name="push")
|
||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deploy_push(uuid: Optional[str]):
|
def deploy_push(uuid: str | None):
|
||||||
"""Deploy the Crew."""
|
"""Deploy the Crew."""
|
||||||
deploy_cmd = DeployCommand()
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.deploy(uuid=uuid)
|
deploy_cmd.deploy(uuid=uuid)
|
||||||
@@ -271,7 +269,7 @@ def deploy_push(uuid: Optional[str]):
|
|||||||
|
|
||||||
@deploy.command(name="status")
|
@deploy.command(name="status")
|
||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deply_status(uuid: Optional[str]):
|
def deply_status(uuid: str | None):
|
||||||
"""Get the status of a deployment."""
|
"""Get the status of a deployment."""
|
||||||
deploy_cmd = DeployCommand()
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.get_crew_status(uuid=uuid)
|
deploy_cmd.get_crew_status(uuid=uuid)
|
||||||
@@ -279,7 +277,7 @@ def deply_status(uuid: Optional[str]):
|
|||||||
|
|
||||||
@deploy.command(name="logs")
|
@deploy.command(name="logs")
|
||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deploy_logs(uuid: Optional[str]):
|
def deploy_logs(uuid: str | None):
|
||||||
"""Get the logs of a deployment."""
|
"""Get the logs of a deployment."""
|
||||||
deploy_cmd = DeployCommand()
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.get_crew_logs(uuid=uuid)
|
deploy_cmd.get_crew_logs(uuid=uuid)
|
||||||
@@ -287,7 +285,7 @@ def deploy_logs(uuid: Optional[str]):
|
|||||||
|
|
||||||
@deploy.command(name="remove")
|
@deploy.command(name="remove")
|
||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deploy_remove(uuid: Optional[str]):
|
def deploy_remove(uuid: str | None):
|
||||||
"""Remove a deployment."""
|
"""Remove a deployment."""
|
||||||
deploy_cmd = DeployCommand()
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.remove_crew(uuid=uuid)
|
deploy_cmd.remove_crew(uuid=uuid)
|
||||||
@@ -327,7 +325,6 @@ def tool_publish(is_public: bool, force: bool):
|
|||||||
@crewai.group()
|
@crewai.group()
|
||||||
def flow():
|
def flow():
|
||||||
"""Flow related commands."""
|
"""Flow related commands."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@flow.command(name="kickoff")
|
@flow.command(name="kickoff")
|
||||||
@@ -359,7 +356,7 @@ def chat():
|
|||||||
and using the Chat LLM to generate responses.
|
and using the Chat LLM to generate responses.
|
||||||
"""
|
"""
|
||||||
click.secho(
|
click.secho(
|
||||||
"\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n",
|
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
|
||||||
)
|
)
|
||||||
|
|
||||||
run_chat()
|
run_chat()
|
||||||
@@ -368,7 +365,6 @@ def chat():
|
|||||||
@crewai.group(invoke_without_command=True)
|
@crewai.group(invoke_without_command=True)
|
||||||
def org():
|
def org():
|
||||||
"""Organization management commands."""
|
"""Organization management commands."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@org.command("list")
|
@org.command("list")
|
||||||
@@ -396,7 +392,6 @@ def current():
|
|||||||
@crewai.group()
|
@crewai.group()
|
||||||
def enterprise():
|
def enterprise():
|
||||||
"""Enterprise Configuration commands."""
|
"""Enterprise Configuration commands."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@enterprise.command("configure")
|
@enterprise.command("configure")
|
||||||
@@ -410,7 +405,6 @@ def enterprise_configure(enterprise_url: str):
|
|||||||
@crewai.group()
|
@crewai.group()
|
||||||
def config():
|
def config():
|
||||||
"""CLI Configuration commands."""
|
"""CLI Configuration commands."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@config.command("list")
|
@config.command("list")
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.cli.constants import (
|
from crewai.cli.constants import (
|
||||||
DEFAULT_CREWAI_ENTERPRISE_URL,
|
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||||
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||||
|
DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||||
)
|
)
|
||||||
from crewai.cli.shared.token_manager import TokenManager
|
from crewai.cli.shared.token_manager import TokenManager
|
||||||
|
|
||||||
@@ -56,20 +55,20 @@ HIDDEN_SETTINGS_KEYS = [
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseModel):
|
class Settings(BaseModel):
|
||||||
enterprise_base_url: Optional[str] = Field(
|
enterprise_base_url: str | None = Field(
|
||||||
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
||||||
description="Base URL of the CrewAI Enterprise instance",
|
description="Base URL of the CrewAI Enterprise instance",
|
||||||
)
|
)
|
||||||
tool_repository_username: Optional[str] = Field(
|
tool_repository_username: str | None = Field(
|
||||||
None, description="Username for interacting with the Tool Repository"
|
None, description="Username for interacting with the Tool Repository"
|
||||||
)
|
)
|
||||||
tool_repository_password: Optional[str] = Field(
|
tool_repository_password: str | None = Field(
|
||||||
None, description="Password for interacting with the Tool Repository"
|
None, description="Password for interacting with the Tool Repository"
|
||||||
)
|
)
|
||||||
org_name: Optional[str] = Field(
|
org_name: str | None = Field(
|
||||||
None, description="Name of the currently active organization"
|
None, description="Name of the currently active organization"
|
||||||
)
|
)
|
||||||
org_uuid: Optional[str] = Field(
|
org_uuid: str | None = Field(
|
||||||
None, description="UUID of the currently active organization"
|
None, description="UUID of the currently active organization"
|
||||||
)
|
)
|
||||||
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
|
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
|
||||||
@@ -79,7 +78,7 @@ class Settings(BaseModel):
|
|||||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
|
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
|
||||||
)
|
)
|
||||||
|
|
||||||
oauth2_audience: Optional[str] = Field(
|
oauth2_audience: str | None = Field(
|
||||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
|
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,48 +16,72 @@ from crewai.cli.utils import copy_template, load_env_vars, write_env_file
|
|||||||
def create_folder_structure(name, parent_folder=None):
|
def create_folder_structure(name, parent_folder=None):
|
||||||
import keyword
|
import keyword
|
||||||
import re
|
import re
|
||||||
|
|
||||||
name = name.rstrip('/')
|
name = name.rstrip("/")
|
||||||
|
|
||||||
if not name.strip():
|
if not name.strip():
|
||||||
raise ValueError("Project name cannot be empty or contain only whitespace")
|
raise ValueError("Project name cannot be empty or contain only whitespace")
|
||||||
|
|
||||||
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
||||||
folder_name = re.sub(r'[^a-zA-Z0-9_]', '', folder_name)
|
folder_name = re.sub(r"[^a-zA-Z0-9_]", "", folder_name)
|
||||||
|
|
||||||
# Check if the name starts with invalid characters or is primarily invalid
|
# Check if the name starts with invalid characters or is primarily invalid
|
||||||
if re.match(r'^[^a-zA-Z0-9_-]+', name):
|
if re.match(r"^[^a-zA-Z0-9_-]+", name):
|
||||||
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' contains no valid characters for a Python module name"
|
||||||
|
)
|
||||||
|
|
||||||
if not folder_name:
|
if not folder_name:
|
||||||
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' contains no valid characters for a Python module name"
|
||||||
|
)
|
||||||
|
|
||||||
if folder_name[0].isdigit():
|
if folder_name[0].isdigit():
|
||||||
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)"
|
||||||
|
)
|
||||||
|
|
||||||
if keyword.iskeyword(folder_name):
|
if keyword.iskeyword(folder_name):
|
||||||
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword"
|
||||||
|
)
|
||||||
|
|
||||||
if not folder_name.isidentifier():
|
if not folder_name.isidentifier():
|
||||||
raise ValueError(f"Project name '{name}' would generate invalid Python module name '{folder_name}'")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
||||||
|
|
||||||
class_name = re.sub(r'[^a-zA-Z0-9_]', '', class_name)
|
class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)
|
||||||
|
|
||||||
if not class_name:
|
if not class_name:
|
||||||
raise ValueError(f"Project name '{name}' contains no valid characters for a Python class name")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' contains no valid characters for a Python class name"
|
||||||
|
)
|
||||||
|
|
||||||
if class_name[0].isdigit():
|
if class_name[0].isdigit():
|
||||||
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit"
|
||||||
|
)
|
||||||
|
|
||||||
# Check if the original name (before title casing) is a keyword
|
# Check if the original name (before title casing) is a keyword
|
||||||
original_name_clean = re.sub(r'[^a-zA-Z0-9_]', '', name.replace("_", "").replace("-", "").lower())
|
original_name_clean = re.sub(
|
||||||
if keyword.iskeyword(original_name_clean) or keyword.iskeyword(class_name) or class_name in ('True', 'False', 'None'):
|
r"[^a-zA-Z0-9_]", "", name.replace("_", "").replace("-", "").lower()
|
||||||
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword")
|
)
|
||||||
|
if (
|
||||||
|
keyword.iskeyword(original_name_clean)
|
||||||
|
or keyword.iskeyword(class_name)
|
||||||
|
or class_name in ("True", "False", "None")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword"
|
||||||
|
)
|
||||||
|
|
||||||
if not class_name.isidentifier():
|
if not class_name.isidentifier():
|
||||||
raise ValueError(f"Project name '{name}' would generate invalid Python class name '{class_name}'")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate invalid Python class name '{class_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
if parent_folder:
|
if parent_folder:
|
||||||
folder_path = Path(parent_folder) / folder_name
|
folder_path = Path(parent_folder) / folder_name
|
||||||
@@ -172,7 +196,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if the selected provider has predefined models
|
# Check if the selected provider has predefined models
|
||||||
if selected_provider in MODELS and MODELS[selected_provider]:
|
if MODELS.get(selected_provider):
|
||||||
while True:
|
while True:
|
||||||
selected_model = select_model(selected_provider, provider_models)
|
selected_model = select_model(selected_provider, provider_models)
|
||||||
if selected_model is None: # User typed 'q'
|
if selected_model is None: # User typed 'q'
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import tomli
|
import tomli
|
||||||
@@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
|
|||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]:
|
def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
|
||||||
"""Initializes the chat LLM and handles exceptions."""
|
"""Initializes the chat LLM and handles exceptions."""
|
||||||
try:
|
try:
|
||||||
return create_llm(crew.chat_llm)
|
return create_llm(crew.chat_llm)
|
||||||
@@ -157,7 +157,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
|
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
|
||||||
"""Creates a wrapper function for running the crew tool with messages."""
|
"""Creates a wrapper function for running the crew tool with messages."""
|
||||||
|
|
||||||
def run_crew_tool_with_messages(**kwargs):
|
def run_crew_tool_with_messages(**kwargs):
|
||||||
@@ -221,9 +221,9 @@ def get_user_input() -> str:
|
|||||||
def handle_user_input(
|
def handle_user_input(
|
||||||
user_input: str,
|
user_input: str,
|
||||||
chat_llm: LLM,
|
chat_llm: LLM,
|
||||||
messages: List[Dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
crew_tool_schema: Dict[str, Any],
|
crew_tool_schema: dict[str, Any],
|
||||||
available_functions: Dict[str, Any],
|
available_functions: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
if user_input.strip().lower() == "exit":
|
if user_input.strip().lower() == "exit":
|
||||||
click.echo("Exiting chat. Goodbye!")
|
click.echo("Exiting chat. Goodbye!")
|
||||||
@@ -281,7 +281,7 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
|
||||||
"""
|
"""
|
||||||
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
|
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
|
||||||
|
|
||||||
@@ -304,9 +304,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
|||||||
crew_output = crew.kickoff(inputs=kwargs)
|
crew_output = crew.kickoff(inputs=kwargs)
|
||||||
|
|
||||||
# Convert CrewOutput to a string to send back to the user
|
# Convert CrewOutput to a string to send back to the user
|
||||||
result = str(crew_output)
|
return str(crew_output)
|
||||||
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Exit the chat and show the error message
|
# Exit the chat and show the error message
|
||||||
click.secho("An error occurred while running the crew:", fg="red")
|
click.secho("An error occurred while running the crew:", fg="red")
|
||||||
@@ -314,7 +313,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def load_crew_and_name() -> Tuple[Crew, str]:
|
def load_crew_and_name() -> tuple[Crew, str]:
|
||||||
"""
|
"""
|
||||||
Loads the crew by importing the crew class from the user's project.
|
Loads the crew by importing the crew class from the user's project.
|
||||||
|
|
||||||
@@ -395,7 +394,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def fetch_required_inputs(crew: Crew) -> Set[str]:
|
def fetch_required_inputs(crew: Crew) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Extracts placeholders from the crew's tasks and agents.
|
Extracts placeholders from the crew's tasks and agents.
|
||||||
|
|
||||||
@@ -406,7 +405,7 @@ def fetch_required_inputs(crew: Crew) -> Set[str]:
|
|||||||
Set[str]: A set of placeholder names.
|
Set[str]: A set of placeholder names.
|
||||||
"""
|
"""
|
||||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||||
required_inputs: Set[str] = set()
|
required_inputs: set[str] = set()
|
||||||
|
|
||||||
# Scan tasks
|
# Scan tasks
|
||||||
for task in crew.tasks:
|
for task in crew.tasks:
|
||||||
@@ -479,9 +478,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
|||||||
f"{context}"
|
f"{context}"
|
||||||
)
|
)
|
||||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||||
description = response.strip()
|
return response.strip()
|
||||||
|
|
||||||
return description
|
|
||||||
|
|
||||||
|
|
||||||
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||||
@@ -531,6 +528,4 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
|||||||
f"{context}"
|
f"{context}"
|
||||||
)
|
)
|
||||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||||
crew_description = response.strip()
|
return response.strip()
|
||||||
|
|
||||||
return crew_description
|
|
||||||
|
|||||||
@@ -64,8 +64,7 @@ class Repository:
|
|||||||
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
|
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
|
||||||
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
|
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
|
||||||
return False
|
return False
|
||||||
else:
|
return True
|
||||||
return True
|
|
||||||
|
|
||||||
def origin_url(self) -> str | None:
|
def origin_url(self) -> str | None:
|
||||||
"""Get the Git repository's remote URL."""
|
"""Get the Git repository's remote URL."""
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ def install_crew(proxy_options: list[str]) -> None:
|
|||||||
Install the crew by running the UV command to lock and install.
|
Install the crew by running the UV command to lock and install.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
command = ["uv", "sync"] + proxy_options
|
command = ["uv", "sync", *proxy_options]
|
||||||
subprocess.run(command, check=True, capture_output=False, text=True)
|
subprocess.run(command, check=True, capture_output=False, text=True)
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from crewai.cli.config import Settings
|
from crewai.cli.config import Settings
|
||||||
from crewai.cli.version import get_crewai_version
|
|
||||||
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||||
|
from crewai.cli.version import get_crewai_version
|
||||||
|
|
||||||
|
|
||||||
class PlusAPI:
|
class PlusAPI:
|
||||||
@@ -56,9 +55,9 @@ class PlusAPI:
|
|||||||
handle: str,
|
handle: str,
|
||||||
is_public: bool,
|
is_public: bool,
|
||||||
version: str,
|
version: str,
|
||||||
description: Optional[str],
|
description: str | None,
|
||||||
encoded_file: str,
|
encoded_file: str,
|
||||||
available_exports: Optional[List[str]] = None,
|
available_exports: list[str] | None = None,
|
||||||
):
|
):
|
||||||
params = {
|
params = {
|
||||||
"handle": handle,
|
"handle": handle,
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import os
|
|
||||||
import certifi
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import certifi
|
||||||
import click
|
import click
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ def select_choice(prompt_message, choices):
|
|||||||
|
|
||||||
provider_models = get_provider_data()
|
provider_models = get_provider_data()
|
||||||
if not provider_models:
|
if not provider_models:
|
||||||
return
|
return None
|
||||||
click.secho(prompt_message, fg="cyan")
|
click.secho(prompt_message, fg="cyan")
|
||||||
for idx, choice in enumerate(choices, start=1):
|
for idx, choice in enumerate(choices, start=1):
|
||||||
click.secho(f"{idx}. {choice}", fg="cyan")
|
click.secho(f"{idx}. {choice}", fg="cyan")
|
||||||
@@ -67,7 +67,7 @@ def select_provider(provider_models):
|
|||||||
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
||||||
|
|
||||||
provider = select_choice(
|
provider = select_choice(
|
||||||
"Select a provider to set up:", predefined_providers + ["other"]
|
"Select a provider to set up:", [*predefined_providers, "other"]
|
||||||
)
|
)
|
||||||
if provider is None: # User typed 'q'
|
if provider is None: # User typed 'q'
|
||||||
return None
|
return None
|
||||||
@@ -102,10 +102,9 @@ def select_model(provider, provider_models):
|
|||||||
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
selected_model = select_choice(
|
return select_choice(
|
||||||
f"Select a model to use for {provider.capitalize()}:", available_models
|
f"Select a model to use for {provider.capitalize()}:", available_models
|
||||||
)
|
)
|
||||||
return selected_model
|
|
||||||
|
|
||||||
|
|
||||||
def load_provider_data(cache_file, cache_expiry):
|
def load_provider_data(cache_file, cache_expiry):
|
||||||
@@ -165,7 +164,7 @@ def fetch_provider_data(cache_file):
|
|||||||
Returns:
|
Returns:
|
||||||
- dict or None: The fetched provider data or None if the operation fails.
|
- dict or None: The fetched provider data or None if the operation fails.
|
||||||
"""
|
"""
|
||||||
ssl_config = os.environ['SSL_CERT_FILE'] = certifi.where()
|
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)
|
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ class TokenManager:
|
|||||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||||
self.save_secure_file(self.file_path, encrypted_data)
|
self.save_secure_file(self.file_path, encrypted_data)
|
||||||
|
|
||||||
def get_token(self) -> Optional[str]:
|
def get_token(self) -> str | None:
|
||||||
"""
|
"""
|
||||||
Get the access token if it is valid and not expired.
|
Get the access token if it is valid and not expired.
|
||||||
|
|
||||||
@@ -113,7 +113,7 @@ class TokenManager:
|
|||||||
# Set appropriate permissions (read/write for owner only)
|
# Set appropriate permissions (read/write for owner only)
|
||||||
os.chmod(file_path, 0o600)
|
os.chmod(file_path, 0o600)
|
||||||
|
|
||||||
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
def read_secure_file(self, filename: str) -> bytes | None:
|
||||||
"""
|
"""
|
||||||
Read the content of a secure file.
|
Read the content of a secure file.
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import sys
|
|||||||
from functools import reduce
|
from functools import reduce
|
||||||
from inspect import getmro, isclass, isfunction, ismethod
|
from inspect import getmro, isclass, isfunction, ismethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, get_type_hints
|
from typing import Any, get_type_hints
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import tomli
|
import tomli
|
||||||
@@ -41,8 +41,7 @@ def copy_template(src, dst, name, class_name, folder_name):
|
|||||||
def read_toml(file_path: str = "pyproject.toml"):
|
def read_toml(file_path: str = "pyproject.toml"):
|
||||||
"""Read the content of a TOML file and return it as a dictionary."""
|
"""Read the content of a TOML file and return it as a dictionary."""
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
toml_dict = tomli.load(f)
|
return tomli.load(f)
|
||||||
return toml_dict
|
|
||||||
|
|
||||||
|
|
||||||
def parse_toml(content):
|
def parse_toml(content):
|
||||||
@@ -77,7 +76,7 @@ def get_project_description(
|
|||||||
|
|
||||||
|
|
||||||
def _get_project_attribute(
|
def _get_project_attribute(
|
||||||
pyproject_path: str, keys: List[str], require: bool
|
pyproject_path: str, keys: list[str], require: bool
|
||||||
) -> Any | None:
|
) -> Any | None:
|
||||||
"""Get an attribute from the pyproject.toml file."""
|
"""Get an attribute from the pyproject.toml file."""
|
||||||
attribute = None
|
attribute = None
|
||||||
@@ -96,7 +95,10 @@ def _get_project_attribute(
|
|||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
console.print(f"Error: {pyproject_path} not found.", style="bold red")
|
console.print(f"Error: {pyproject_path} not found.", style="bold red")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
console.print(f"Error: {pyproject_path} is not a valid pyproject.toml file.", style="bold red")
|
console.print(
|
||||||
|
f"Error: {pyproject_path} is not a valid pyproject.toml file.",
|
||||||
|
style="bold red",
|
||||||
|
)
|
||||||
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
|
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
|
||||||
console.print(
|
console.print(
|
||||||
f"Error: {pyproject_path} is not a valid TOML file."
|
f"Error: {pyproject_path} is not a valid TOML file."
|
||||||
@@ -117,7 +119,7 @@ def _get_project_attribute(
|
|||||||
return attribute
|
return attribute
|
||||||
|
|
||||||
|
|
||||||
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
|
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
|
||||||
return reduce(dict.__getitem__, keys, data)
|
return reduce(dict.__getitem__, keys, data)
|
||||||
|
|
||||||
|
|
||||||
@@ -296,7 +298,10 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
|||||||
try:
|
try:
|
||||||
crew_instances.extend(fetch_crews(module_attr))
|
crew_instances.extend(fetch_crews(module_attr))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"Error processing attribute {attr_name}: {e}", style="bold red")
|
console.print(
|
||||||
|
f"Error processing attribute {attr_name}: {e}",
|
||||||
|
style="bold red",
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If we found crew instances, break out of the loop
|
# If we found crew instances, break out of the loop
|
||||||
@@ -304,12 +309,15 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
|||||||
break
|
break
|
||||||
|
|
||||||
except Exception as exec_error:
|
except Exception as exec_error:
|
||||||
console.print(f"Error executing module: {exec_error}", style="bold red")
|
console.print(
|
||||||
|
f"Error executing module: {exec_error}",
|
||||||
|
style="bold red",
|
||||||
|
)
|
||||||
|
|
||||||
except (ImportError, AttributeError) as e:
|
except (ImportError, AttributeError) as e:
|
||||||
if require:
|
if require:
|
||||||
console.print(
|
console.print(
|
||||||
f"Error importing crew from {crew_path}: {str(e)}",
|
f"Error importing crew from {crew_path}: {e!s}",
|
||||||
style="bold red",
|
style="bold red",
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
@@ -325,7 +333,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if require:
|
if require:
|
||||||
console.print(
|
console.print(
|
||||||
f"Unexpected error while loading crew: {str(e)}", style="bold red"
|
f"Unexpected error while loading crew: {e!s}", style="bold red"
|
||||||
)
|
)
|
||||||
raise SystemExit
|
raise SystemExit
|
||||||
return crew_instances
|
return crew_instances
|
||||||
@@ -348,8 +356,7 @@ def get_crew_instance(module_attr) -> Crew | None:
|
|||||||
|
|
||||||
if isinstance(module_attr, Crew):
|
if isinstance(module_attr, Crew):
|
||||||
return module_attr
|
return module_attr
|
||||||
else:
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_crews(module_attr) -> list[Crew]:
|
def fetch_crews(module_attr) -> list[Crew]:
|
||||||
@@ -402,7 +409,7 @@ def extract_available_exports(dir_path: str = "src"):
|
|||||||
return available_exports
|
return available_exports
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"[red]Error: Could not extract tool classes: {str(e)}[/red]")
|
console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]")
|
||||||
console.print(
|
console.print(
|
||||||
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
|
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
|
||||||
)
|
)
|
||||||
@@ -440,7 +447,7 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"[red]Warning: Could not load {init_file}: {str(e)}[/red]")
|
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -1,21 +1,23 @@
|
|||||||
import os
|
|
||||||
import contextvars
|
import contextvars
|
||||||
from typing import Optional
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
_platform_integration_token: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
|
_platform_integration_token: contextvars.ContextVar[str | None] = (
|
||||||
"platform_integration_token", default=None
|
contextvars.ContextVar("platform_integration_token", default=None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_platform_integration_token(integration_token: str) -> None:
|
def set_platform_integration_token(integration_token: str) -> None:
|
||||||
_platform_integration_token.set(integration_token)
|
_platform_integration_token.set(integration_token)
|
||||||
|
|
||||||
def get_platform_integration_token() -> Optional[str]:
|
|
||||||
|
def get_platform_integration_token() -> str | None:
|
||||||
token = _platform_integration_token.get()
|
token = _platform_integration_token.get()
|
||||||
if token is None:
|
if token is None:
|
||||||
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
|
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def platform_context(integration_token: str):
|
def platform_context(integration_token: str):
|
||||||
token = _platform_integration_token.set(integration_token)
|
token = _platform_integration_token.set(integration_token)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -12,10 +12,10 @@ class CrewOutput(BaseModel):
|
|||||||
"""Class that represents the result of a crew."""
|
"""Class that represents the result of a crew."""
|
||||||
|
|
||||||
raw: str = Field(description="Raw output of crew", default="")
|
raw: str = Field(description="Raw output of crew", default="")
|
||||||
pydantic: Optional[BaseModel] = Field(
|
pydantic: BaseModel | None = Field(
|
||||||
description="Pydantic output of Crew", default=None
|
description="Pydantic output of Crew", default=None
|
||||||
)
|
)
|
||||||
json_dict: Optional[Dict[str, Any]] = Field(
|
json_dict: dict[str, Any] | None = Field(
|
||||||
description="JSON dict output of Crew", default=None
|
description="JSON dict output of Crew", default=None
|
||||||
)
|
)
|
||||||
tasks_output: list[TaskOutput] = Field(
|
tasks_output: list[TaskOutput] = Field(
|
||||||
@@ -24,7 +24,7 @@ class CrewOutput(BaseModel):
|
|||||||
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
|
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json(self) -> Optional[str]:
|
def json(self) -> str | None:
|
||||||
if self.tasks_output[-1].output_format != OutputFormat.JSON:
|
if self.tasks_output[-1].output_format != OutputFormat.JSON:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
|
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
|
||||||
@@ -32,7 +32,7 @@ class CrewOutput(BaseModel):
|
|||||||
|
|
||||||
return json.dumps(self.json_dict)
|
return json.dumps(self.json_dict)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""Convert json_output and pydantic_output to a dictionary."""
|
"""Convert json_output and pydantic_output to a dictionary."""
|
||||||
output_dict = {}
|
output_dict = {}
|
||||||
if self.json_dict:
|
if self.json_dict:
|
||||||
@@ -44,10 +44,9 @@ class CrewOutput(BaseModel):
|
|||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if self.pydantic and hasattr(self.pydantic, key):
|
if self.pydantic and hasattr(self.pydantic, key):
|
||||||
return getattr(self.pydantic, key)
|
return getattr(self.pydantic, key)
|
||||||
elif self.json_dict and key in self.json_dict:
|
if self.json_dict and key in self.json_dict:
|
||||||
return self.json_dict[key]
|
return self.json_dict[key]
|
||||||
else:
|
raise KeyError(f"Key '{key}' not found in CrewOutput.")
|
||||||
raise KeyError(f"Key '{key}' not found in CrewOutput.")
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self.pydantic:
|
if self.pydantic:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.utilities.serialization import to_serializable
|
from crewai.utilities.serialization import to_serializable
|
||||||
@@ -10,11 +11,11 @@ class BaseEvent(BaseModel):
|
|||||||
|
|
||||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
type: str
|
type: str
|
||||||
source_fingerprint: Optional[str] = None # UUID string of the source entity
|
source_fingerprint: str | None = None # UUID string of the source entity
|
||||||
source_type: Optional[str] = (
|
source_type: str | None = (
|
||||||
None # "agent", "task", "crew", "memory", "entity_memory", "short_term_memory", "long_term_memory", "external_memory"
|
None # "agent", "task", "crew", "memory", "entity_memory", "short_term_memory", "long_term_memory", "external_memory"
|
||||||
)
|
)
|
||||||
fingerprint_metadata: Optional[Dict[str, Any]] = None # Any relevant metadata
|
fingerprint_metadata: dict[str, Any] | None = None # Any relevant metadata
|
||||||
|
|
||||||
def to_json(self, exclude: set[str] | None = None):
|
def to_json(self, exclude: set[str] | None = None):
|
||||||
"""
|
"""
|
||||||
@@ -28,13 +29,13 @@ class BaseEvent(BaseModel):
|
|||||||
"""
|
"""
|
||||||
return to_serializable(self, exclude=exclude)
|
return to_serializable(self, exclude=exclude)
|
||||||
|
|
||||||
def _set_task_params(self, data: Dict[str, Any]):
|
def _set_task_params(self, data: dict[str, Any]):
|
||||||
if "from_task" in data and (task := data["from_task"]):
|
if "from_task" in data and (task := data["from_task"]):
|
||||||
self.task_id = task.id
|
self.task_id = task.id
|
||||||
self.task_name = task.name or task.description
|
self.task_name = task.name or task.description
|
||||||
self.from_task = None
|
self.from_task = None
|
||||||
|
|
||||||
def _set_agent_params(self, data: Dict[str, Any]):
|
def _set_agent_params(self, data: dict[str, Any]):
|
||||||
task = data.get("from_task", None)
|
task = data.get("from_task", None)
|
||||||
agent = task.agent if task else data.get("from_agent", None)
|
agent = task.agent if task else data.get("from_agent", None)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
from collections.abc import Callable
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Dict, List, Type, TypeVar, cast
|
from typing import Any, TypeVar, cast
|
||||||
|
|
||||||
from blinker import Signal
|
from blinker import Signal
|
||||||
|
|
||||||
@@ -25,17 +26,17 @@ class CrewAIEventsBus:
|
|||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
if cls._instance is None: # prevent race condition
|
if cls._instance is None: # prevent race condition
|
||||||
cls._instance = super(CrewAIEventsBus, cls).__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
cls._instance._initialize()
|
cls._instance._initialize()
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def _initialize(self) -> None:
|
def _initialize(self) -> None:
|
||||||
"""Initialize the event bus internal state"""
|
"""Initialize the event bus internal state"""
|
||||||
self._signal = Signal("crewai_event_bus")
|
self._signal = Signal("crewai_event_bus")
|
||||||
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {}
|
self._handlers: dict[type[BaseEvent], list[Callable]] = {}
|
||||||
|
|
||||||
def on(
|
def on(
|
||||||
self, event_type: Type[EventT]
|
self, event_type: type[EventT]
|
||||||
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
|
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
|
||||||
"""
|
"""
|
||||||
Decorator to register an event handler for a specific event type.
|
Decorator to register an event handler for a specific event type.
|
||||||
@@ -82,7 +83,7 @@ class CrewAIEventsBus:
|
|||||||
self._signal.send(source, event=event)
|
self._signal.send(source, event=event)
|
||||||
|
|
||||||
def register_handler(
|
def register_handler(
|
||||||
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
self, event_type: type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register an event handler for a specific event type"""
|
"""Register an event handler for a specific event type"""
|
||||||
if event_type not in self._handlers:
|
if event_type not in self._handlers:
|
||||||
|
|||||||
@@ -1,15 +1,30 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field, PrivateAttr
|
from pydantic import Field, PrivateAttr
|
||||||
from crewai.llm import LLM
|
|
||||||
from crewai.task import Task
|
|
||||||
from crewai.telemetry.telemetry import Telemetry
|
|
||||||
from crewai.utilities import Logger
|
|
||||||
from crewai.utilities.constants import EMITTER_COLOR
|
|
||||||
from crewai.events.base_event_listener import BaseEventListener
|
from crewai.events.base_event_listener import BaseEventListener
|
||||||
|
from crewai.events.types.agent_events import (
|
||||||
|
AgentExecutionCompletedEvent,
|
||||||
|
AgentExecutionStartedEvent,
|
||||||
|
LiteAgentExecutionCompletedEvent,
|
||||||
|
LiteAgentExecutionErrorEvent,
|
||||||
|
LiteAgentExecutionStartedEvent,
|
||||||
|
)
|
||||||
|
from crewai.events.types.crew_events import (
|
||||||
|
CrewKickoffCompletedEvent,
|
||||||
|
CrewKickoffFailedEvent,
|
||||||
|
CrewKickoffStartedEvent,
|
||||||
|
CrewTestCompletedEvent,
|
||||||
|
CrewTestFailedEvent,
|
||||||
|
CrewTestResultEvent,
|
||||||
|
CrewTestStartedEvent,
|
||||||
|
CrewTrainCompletedEvent,
|
||||||
|
CrewTrainFailedEvent,
|
||||||
|
CrewTrainStartedEvent,
|
||||||
|
)
|
||||||
from crewai.events.types.knowledge_events import (
|
from crewai.events.types.knowledge_events import (
|
||||||
KnowledgeQueryCompletedEvent,
|
KnowledgeQueryCompletedEvent,
|
||||||
KnowledgeQueryFailedEvent,
|
KnowledgeQueryFailedEvent,
|
||||||
@@ -25,34 +40,21 @@ from crewai.events.types.llm_events import (
|
|||||||
LLMStreamChunkEvent,
|
LLMStreamChunkEvent,
|
||||||
)
|
)
|
||||||
from crewai.events.types.llm_guardrail_events import (
|
from crewai.events.types.llm_guardrail_events import (
|
||||||
LLMGuardrailStartedEvent,
|
|
||||||
LLMGuardrailCompletedEvent,
|
LLMGuardrailCompletedEvent,
|
||||||
)
|
LLMGuardrailStartedEvent,
|
||||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
|
||||||
|
|
||||||
from crewai.events.types.agent_events import (
|
|
||||||
AgentExecutionCompletedEvent,
|
|
||||||
AgentExecutionStartedEvent,
|
|
||||||
LiteAgentExecutionCompletedEvent,
|
|
||||||
LiteAgentExecutionErrorEvent,
|
|
||||||
LiteAgentExecutionStartedEvent,
|
|
||||||
)
|
)
|
||||||
from crewai.events.types.logging_events import (
|
from crewai.events.types.logging_events import (
|
||||||
AgentLogsStartedEvent,
|
|
||||||
AgentLogsExecutionEvent,
|
AgentLogsExecutionEvent,
|
||||||
|
AgentLogsStartedEvent,
|
||||||
)
|
)
|
||||||
from crewai.events.types.crew_events import (
|
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||||
CrewKickoffCompletedEvent,
|
from crewai.llm import LLM
|
||||||
CrewKickoffFailedEvent,
|
from crewai.task import Task
|
||||||
CrewKickoffStartedEvent,
|
from crewai.telemetry.telemetry import Telemetry
|
||||||
CrewTestCompletedEvent,
|
from crewai.utilities import Logger
|
||||||
CrewTestFailedEvent,
|
from crewai.utilities.constants import EMITTER_COLOR
|
||||||
CrewTestResultEvent,
|
|
||||||
CrewTestStartedEvent,
|
from .listeners.memory_listener import MemoryListener
|
||||||
CrewTrainCompletedEvent,
|
|
||||||
CrewTrainFailedEvent,
|
|
||||||
CrewTrainStartedEvent,
|
|
||||||
)
|
|
||||||
from .types.flow_events import (
|
from .types.flow_events import (
|
||||||
FlowCreatedEvent,
|
FlowCreatedEvent,
|
||||||
FlowFinishedEvent,
|
FlowFinishedEvent,
|
||||||
@@ -61,26 +63,24 @@ from .types.flow_events import (
|
|||||||
MethodExecutionFinishedEvent,
|
MethodExecutionFinishedEvent,
|
||||||
MethodExecutionStartedEvent,
|
MethodExecutionStartedEvent,
|
||||||
)
|
)
|
||||||
|
from .types.reasoning_events import (
|
||||||
|
AgentReasoningCompletedEvent,
|
||||||
|
AgentReasoningFailedEvent,
|
||||||
|
AgentReasoningStartedEvent,
|
||||||
|
)
|
||||||
from .types.task_events import TaskCompletedEvent, TaskFailedEvent, TaskStartedEvent
|
from .types.task_events import TaskCompletedEvent, TaskFailedEvent, TaskStartedEvent
|
||||||
from .types.tool_usage_events import (
|
from .types.tool_usage_events import (
|
||||||
ToolUsageErrorEvent,
|
ToolUsageErrorEvent,
|
||||||
ToolUsageFinishedEvent,
|
ToolUsageFinishedEvent,
|
||||||
ToolUsageStartedEvent,
|
ToolUsageStartedEvent,
|
||||||
)
|
)
|
||||||
from .types.reasoning_events import (
|
|
||||||
AgentReasoningStartedEvent,
|
|
||||||
AgentReasoningCompletedEvent,
|
|
||||||
AgentReasoningFailedEvent,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .listeners.memory_listener import MemoryListener
|
|
||||||
|
|
||||||
|
|
||||||
class EventListener(BaseEventListener):
|
class EventListener(BaseEventListener):
|
||||||
_instance = None
|
_instance = None
|
||||||
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
|
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
|
||||||
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
|
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
|
||||||
execution_spans: Dict[Task, Any] = Field(default_factory=dict)
|
execution_spans: dict[Task, Any] = Field(default_factory=dict)
|
||||||
next_chunk = 0
|
next_chunk = 0
|
||||||
text_stream = StringIO()
|
text_stream = StringIO()
|
||||||
knowledge_retrieval_in_progress = False
|
knowledge_retrieval_in_progress = False
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from crewai.events.types.agent_events import (
|
|||||||
AgentExecutionStartedEvent,
|
AgentExecutionStartedEvent,
|
||||||
LiteAgentExecutionCompletedEvent,
|
LiteAgentExecutionCompletedEvent,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .types.crew_events import (
|
from .types.crew_events import (
|
||||||
CrewKickoffCompletedEvent,
|
CrewKickoffCompletedEvent,
|
||||||
CrewKickoffFailedEvent,
|
CrewKickoffFailedEvent,
|
||||||
@@ -24,6 +25,14 @@ from .types.flow_events import (
|
|||||||
MethodExecutionFinishedEvent,
|
MethodExecutionFinishedEvent,
|
||||||
MethodExecutionStartedEvent,
|
MethodExecutionStartedEvent,
|
||||||
)
|
)
|
||||||
|
from .types.knowledge_events import (
|
||||||
|
KnowledgeQueryCompletedEvent,
|
||||||
|
KnowledgeQueryFailedEvent,
|
||||||
|
KnowledgeQueryStartedEvent,
|
||||||
|
KnowledgeRetrievalCompletedEvent,
|
||||||
|
KnowledgeRetrievalStartedEvent,
|
||||||
|
KnowledgeSearchQueryFailedEvent,
|
||||||
|
)
|
||||||
from .types.llm_events import (
|
from .types.llm_events import (
|
||||||
LLMCallCompletedEvent,
|
LLMCallCompletedEvent,
|
||||||
LLMCallFailedEvent,
|
LLMCallFailedEvent,
|
||||||
@@ -34,6 +43,21 @@ from .types.llm_guardrail_events import (
|
|||||||
LLMGuardrailCompletedEvent,
|
LLMGuardrailCompletedEvent,
|
||||||
LLMGuardrailStartedEvent,
|
LLMGuardrailStartedEvent,
|
||||||
)
|
)
|
||||||
|
from .types.memory_events import (
|
||||||
|
MemoryQueryCompletedEvent,
|
||||||
|
MemoryQueryFailedEvent,
|
||||||
|
MemoryQueryStartedEvent,
|
||||||
|
MemoryRetrievalCompletedEvent,
|
||||||
|
MemoryRetrievalStartedEvent,
|
||||||
|
MemorySaveCompletedEvent,
|
||||||
|
MemorySaveFailedEvent,
|
||||||
|
MemorySaveStartedEvent,
|
||||||
|
)
|
||||||
|
from .types.reasoning_events import (
|
||||||
|
AgentReasoningCompletedEvent,
|
||||||
|
AgentReasoningFailedEvent,
|
||||||
|
AgentReasoningStartedEvent,
|
||||||
|
)
|
||||||
from .types.task_events import (
|
from .types.task_events import (
|
||||||
TaskCompletedEvent,
|
TaskCompletedEvent,
|
||||||
TaskFailedEvent,
|
TaskFailedEvent,
|
||||||
@@ -44,30 +68,6 @@ from .types.tool_usage_events import (
|
|||||||
ToolUsageFinishedEvent,
|
ToolUsageFinishedEvent,
|
||||||
ToolUsageStartedEvent,
|
ToolUsageStartedEvent,
|
||||||
)
|
)
|
||||||
from .types.reasoning_events import (
|
|
||||||
AgentReasoningStartedEvent,
|
|
||||||
AgentReasoningCompletedEvent,
|
|
||||||
AgentReasoningFailedEvent,
|
|
||||||
)
|
|
||||||
from .types.knowledge_events import (
|
|
||||||
KnowledgeRetrievalStartedEvent,
|
|
||||||
KnowledgeRetrievalCompletedEvent,
|
|
||||||
KnowledgeQueryStartedEvent,
|
|
||||||
KnowledgeQueryCompletedEvent,
|
|
||||||
KnowledgeQueryFailedEvent,
|
|
||||||
KnowledgeSearchQueryFailedEvent,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .types.memory_events import (
|
|
||||||
MemorySaveStartedEvent,
|
|
||||||
MemorySaveCompletedEvent,
|
|
||||||
MemorySaveFailedEvent,
|
|
||||||
MemoryQueryStartedEvent,
|
|
||||||
MemoryQueryCompletedEvent,
|
|
||||||
MemoryQueryFailedEvent,
|
|
||||||
MemoryRetrievalStartedEvent,
|
|
||||||
MemoryRetrievalCompletedEvent,
|
|
||||||
)
|
|
||||||
|
|
||||||
EventTypes = Union[
|
EventTypes = Union[
|
||||||
CrewKickoffStartedEvent,
|
CrewKickoffStartedEvent,
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
|
|
||||||
This module contains various event listener implementations
|
This module contains various event listener implementations
|
||||||
for handling memory, tracing, and other event-driven functionality.
|
for handling memory, tracing, and other event-driven functionality.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from crewai.events.base_event_listener import BaseEventListener
|
from crewai.events.base_event_listener import BaseEventListener
|
||||||
from crewai.events.types.memory_events import (
|
from crewai.events.types.memory_events import (
|
||||||
|
MemoryQueryCompletedEvent,
|
||||||
|
MemoryQueryFailedEvent,
|
||||||
MemoryRetrievalCompletedEvent,
|
MemoryRetrievalCompletedEvent,
|
||||||
MemoryRetrievalStartedEvent,
|
MemoryRetrievalStartedEvent,
|
||||||
MemoryQueryFailedEvent,
|
|
||||||
MemoryQueryCompletedEvent,
|
|
||||||
MemorySaveStartedEvent,
|
|
||||||
MemorySaveCompletedEvent,
|
MemorySaveCompletedEvent,
|
||||||
MemorySaveFailedEvent,
|
MemorySaveFailedEvent,
|
||||||
|
MemorySaveStartedEvent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from dataclasses import dataclass, field, asdict
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Dict, Any
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -13,7 +13,7 @@ class TraceEvent:
|
|||||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||||
)
|
)
|
||||||
type: str = ""
|
type: str = ""
|
||||||
event_data: Dict[str, Any] = field(default_factory=dict)
|
event_data: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
|
|
||||||
This module contains all event types used throughout the CrewAI system
|
This module contains all event types used throughout the CrewAI system
|
||||||
for monitoring and extending agent, crew, task, and tool execution.
|
for monitoring and extending agent, crew, task, and tool execution.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,14 +2,15 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import model_validator
|
from pydantic import ConfigDict, model_validator
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
|
from crewai.events.base_events import BaseEvent
|
||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
from crewai.tools.structured_tool import CrewStructuredTool
|
from crewai.tools.structured_tool import CrewStructuredTool
|
||||||
from crewai.events.base_events import BaseEvent
|
|
||||||
|
|
||||||
|
|
||||||
class AgentExecutionStartedEvent(BaseEvent):
|
class AgentExecutionStartedEvent(BaseEvent):
|
||||||
@@ -17,11 +18,11 @@ class AgentExecutionStartedEvent(BaseEvent):
|
|||||||
|
|
||||||
agent: BaseAgent
|
agent: BaseAgent
|
||||||
task: Any
|
task: Any
|
||||||
tools: Optional[Sequence[Union[BaseTool, CrewStructuredTool]]]
|
tools: Sequence[BaseTool | CrewStructuredTool] | None
|
||||||
task_prompt: str
|
task_prompt: str
|
||||||
type: str = "agent_execution_started"
|
type: str = "agent_execution_started"
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_fingerprint_data(self):
|
def set_fingerprint_data(self):
|
||||||
@@ -45,7 +46,7 @@ class AgentExecutionCompletedEvent(BaseEvent):
|
|||||||
output: str
|
output: str
|
||||||
type: str = "agent_execution_completed"
|
type: str = "agent_execution_completed"
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_fingerprint_data(self):
|
def set_fingerprint_data(self):
|
||||||
@@ -69,7 +70,7 @@ class AgentExecutionErrorEvent(BaseEvent):
|
|||||||
error: str
|
error: str
|
||||||
type: str = "agent_execution_error"
|
type: str = "agent_execution_error"
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_fingerprint_data(self):
|
def set_fingerprint_data(self):
|
||||||
@@ -89,18 +90,18 @@ class AgentExecutionErrorEvent(BaseEvent):
|
|||||||
class LiteAgentExecutionStartedEvent(BaseEvent):
|
class LiteAgentExecutionStartedEvent(BaseEvent):
|
||||||
"""Event emitted when a LiteAgent starts executing"""
|
"""Event emitted when a LiteAgent starts executing"""
|
||||||
|
|
||||||
agent_info: Dict[str, Any]
|
agent_info: dict[str, Any]
|
||||||
tools: Optional[Sequence[Union[BaseTool, CrewStructuredTool]]]
|
tools: Sequence[BaseTool | CrewStructuredTool] | None
|
||||||
messages: Union[str, List[Dict[str, str]]]
|
messages: str | list[dict[str, str]]
|
||||||
type: str = "lite_agent_execution_started"
|
type: str = "lite_agent_execution_started"
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
|
||||||
class LiteAgentExecutionCompletedEvent(BaseEvent):
|
class LiteAgentExecutionCompletedEvent(BaseEvent):
|
||||||
"""Event emitted when a LiteAgent completes execution"""
|
"""Event emitted when a LiteAgent completes execution"""
|
||||||
|
|
||||||
agent_info: Dict[str, Any]
|
agent_info: dict[str, Any]
|
||||||
output: str
|
output: str
|
||||||
type: str = "lite_agent_execution_completed"
|
type: str = "lite_agent_execution_completed"
|
||||||
|
|
||||||
@@ -108,7 +109,7 @@ class LiteAgentExecutionCompletedEvent(BaseEvent):
|
|||||||
class LiteAgentExecutionErrorEvent(BaseEvent):
|
class LiteAgentExecutionErrorEvent(BaseEvent):
|
||||||
"""Event emitted when a LiteAgent encounters an error during execution"""
|
"""Event emitted when a LiteAgent encounters an error during execution"""
|
||||||
|
|
||||||
agent_info: Dict[str, Any]
|
agent_info: dict[str, Any]
|
||||||
error: str
|
error: str
|
||||||
type: str = "lite_agent_execution_error"
|
type: str = "lite_agent_execution_error"
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from crewai.events.base_events import BaseEvent
|
from crewai.events.base_events import BaseEvent
|
||||||
|
|
||||||
@@ -11,8 +11,8 @@ else:
|
|||||||
class CrewBaseEvent(BaseEvent):
|
class CrewBaseEvent(BaseEvent):
|
||||||
"""Base class for crew events with fingerprint handling"""
|
"""Base class for crew events with fingerprint handling"""
|
||||||
|
|
||||||
crew_name: Optional[str]
|
crew_name: str | None
|
||||||
crew: Optional[Crew] = None
|
crew: Crew | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
@@ -38,7 +38,7 @@ class CrewBaseEvent(BaseEvent):
|
|||||||
class CrewKickoffStartedEvent(CrewBaseEvent):
|
class CrewKickoffStartedEvent(CrewBaseEvent):
|
||||||
"""Event emitted when a crew starts execution"""
|
"""Event emitted when a crew starts execution"""
|
||||||
|
|
||||||
inputs: Optional[Dict[str, Any]]
|
inputs: dict[str, Any] | None
|
||||||
type: str = "crew_kickoff_started"
|
type: str = "crew_kickoff_started"
|
||||||
|
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ class CrewTrainStartedEvent(CrewBaseEvent):
|
|||||||
|
|
||||||
n_iterations: int
|
n_iterations: int
|
||||||
filename: str
|
filename: str
|
||||||
inputs: Optional[Dict[str, Any]]
|
inputs: dict[str, Any] | None
|
||||||
type: str = "crew_train_started"
|
type: str = "crew_train_started"
|
||||||
|
|
||||||
|
|
||||||
@@ -85,8 +85,8 @@ class CrewTestStartedEvent(CrewBaseEvent):
|
|||||||
"""Event emitted when a crew starts testing"""
|
"""Event emitted when a crew starts testing"""
|
||||||
|
|
||||||
n_iterations: int
|
n_iterations: int
|
||||||
eval_llm: Optional[Union[str, Any]]
|
eval_llm: str | Any | None
|
||||||
inputs: Optional[Dict[str, Any]]
|
inputs: dict[str, Any] | None
|
||||||
type: str = "crew_test_started"
|
type: str = "crew_test_started"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
@@ -16,7 +16,7 @@ class FlowStartedEvent(FlowEvent):
|
|||||||
"""Event emitted when a flow starts execution"""
|
"""Event emitted when a flow starts execution"""
|
||||||
|
|
||||||
flow_name: str
|
flow_name: str
|
||||||
inputs: Optional[Dict[str, Any]] = None
|
inputs: dict[str, Any] | None = None
|
||||||
type: str = "flow_started"
|
type: str = "flow_started"
|
||||||
|
|
||||||
|
|
||||||
@@ -32,8 +32,8 @@ class MethodExecutionStartedEvent(FlowEvent):
|
|||||||
|
|
||||||
flow_name: str
|
flow_name: str
|
||||||
method_name: str
|
method_name: str
|
||||||
state: Union[Dict[str, Any], BaseModel]
|
state: dict[str, Any] | BaseModel
|
||||||
params: Optional[Dict[str, Any]] = None
|
params: dict[str, Any] | None = None
|
||||||
type: str = "method_execution_started"
|
type: str = "method_execution_started"
|
||||||
|
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ class MethodExecutionFinishedEvent(FlowEvent):
|
|||||||
flow_name: str
|
flow_name: str
|
||||||
method_name: str
|
method_name: str
|
||||||
result: Any = None
|
result: Any = None
|
||||||
state: Union[Dict[str, Any], BaseModel]
|
state: dict[str, Any] | BaseModel
|
||||||
type: str = "method_execution_finished"
|
type: str = "method_execution_finished"
|
||||||
|
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ class FlowFinishedEvent(FlowEvent):
|
|||||||
"""Event emitted when a flow completes execution"""
|
"""Event emitted when a flow completes execution"""
|
||||||
|
|
||||||
flow_name: str
|
flow_name: str
|
||||||
result: Optional[Any] = None
|
result: Any | None = None
|
||||||
type: str = "flow_finished"
|
type: str = "flow_finished"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from crewai.events.base_events import BaseEvent
|
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
|
from crewai.events.base_events import BaseEvent
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeRetrievalStartedEvent(BaseEvent):
|
class KnowledgeRetrievalStartedEvent(BaseEvent):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -7,14 +7,14 @@ from crewai.events.base_events import BaseEvent
|
|||||||
|
|
||||||
|
|
||||||
class LLMEventBase(BaseEvent):
|
class LLMEventBase(BaseEvent):
|
||||||
task_name: Optional[str] = None
|
task_name: str | None = None
|
||||||
task_id: Optional[str] = None
|
task_id: str | None = None
|
||||||
|
|
||||||
agent_id: Optional[str] = None
|
agent_id: str | None = None
|
||||||
agent_role: Optional[str] = None
|
agent_role: str | None = None
|
||||||
|
|
||||||
from_task: Optional[Any] = None
|
from_task: Any | None = None
|
||||||
from_agent: Optional[Any] = None
|
from_agent: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
@@ -38,11 +38,11 @@ class LLMCallStartedEvent(LLMEventBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "llm_call_started"
|
type: str = "llm_call_started"
|
||||||
model: Optional[str] = None
|
model: str | None = None
|
||||||
messages: Optional[Union[str, List[Dict[str, Any]]]] = None
|
messages: str | list[dict[str, Any]] | None = None
|
||||||
tools: Optional[List[dict[str, Any]]] = None
|
tools: list[dict[str, Any]] | None = None
|
||||||
callbacks: Optional[List[Any]] = None
|
callbacks: list[Any] | None = None
|
||||||
available_functions: Optional[Dict[str, Any]] = None
|
available_functions: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class LLMCallCompletedEvent(LLMEventBase):
|
class LLMCallCompletedEvent(LLMEventBase):
|
||||||
@@ -52,7 +52,7 @@ class LLMCallCompletedEvent(LLMEventBase):
|
|||||||
messages: str | list[dict[str, Any]] | None = None
|
messages: str | list[dict[str, Any]] | None = None
|
||||||
response: Any
|
response: Any
|
||||||
call_type: LLMCallType
|
call_type: LLMCallType
|
||||||
model: Optional[str] = None
|
model: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class LLMCallFailedEvent(LLMEventBase):
|
class LLMCallFailedEvent(LLMEventBase):
|
||||||
@@ -64,13 +64,13 @@ class LLMCallFailedEvent(LLMEventBase):
|
|||||||
|
|
||||||
class FunctionCall(BaseModel):
|
class FunctionCall(BaseModel):
|
||||||
arguments: str
|
arguments: str
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ToolCall(BaseModel):
|
class ToolCall(BaseModel):
|
||||||
id: Optional[str] = None
|
id: str | None = None
|
||||||
function: FunctionCall
|
function: FunctionCall
|
||||||
type: Optional[str] = None
|
type: str | None = None
|
||||||
index: int
|
index: int
|
||||||
|
|
||||||
|
|
||||||
@@ -79,4 +79,4 @@ class LLMStreamChunkEvent(LLMEventBase):
|
|||||||
|
|
||||||
type: str = "llm_stream_chunk"
|
type: str = "llm_stream_chunk"
|
||||||
chunk: str
|
chunk: str
|
||||||
tool_call: Optional[ToolCall] = None
|
tool_call: ToolCall | None = None
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from inspect import getsource
|
from inspect import getsource
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
from crewai.events.base_events import BaseEvent
|
from crewai.events.base_events import BaseEvent
|
||||||
|
|
||||||
@@ -13,12 +14,12 @@ class LLMGuardrailStartedEvent(BaseEvent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "llm_guardrail_started"
|
type: str = "llm_guardrail_started"
|
||||||
guardrail: Union[str, Callable]
|
guardrail: str | Callable
|
||||||
retry_count: int
|
retry_count: int
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
|
||||||
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
|
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
|
||||||
|
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||||
|
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
|
||||||
@@ -41,5 +42,5 @@ class LLMGuardrailCompletedEvent(BaseEvent):
|
|||||||
type: str = "llm_guardrail_completed"
|
type: str = "llm_guardrail_completed"
|
||||||
success: bool
|
success: bool
|
||||||
result: Any
|
result: Any
|
||||||
error: Optional[str] = None
|
error: str | None = None
|
||||||
retry_count: int
|
retry_count: int
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Agent logging events that don't reference BaseAgent to avoid circular imports."""
|
"""Agent logging events that don't reference BaseAgent to avoid circular imports."""
|
||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
from crewai.events.base_events import BaseEvent
|
from crewai.events.base_events import BaseEvent
|
||||||
|
|
||||||
@@ -9,7 +11,7 @@ class AgentLogsStartedEvent(BaseEvent):
|
|||||||
"""Event emitted when agent logs should be shown at start"""
|
"""Event emitted when agent logs should be shown at start"""
|
||||||
|
|
||||||
agent_role: str
|
agent_role: str
|
||||||
task_description: Optional[str] = None
|
task_description: str | None = None
|
||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
type: str = "agent_logs_started"
|
type: str = "agent_logs_started"
|
||||||
|
|
||||||
@@ -22,4 +24,4 @@ class AgentLogsExecutionEvent(BaseEvent):
|
|||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
type: str = "agent_logs_execution"
|
type: str = "agent_logs_execution"
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
from crewai.events.base_events import BaseEvent
|
from crewai.events.base_events import BaseEvent
|
||||||
|
|
||||||
@@ -7,12 +7,12 @@ class MemoryBaseEvent(BaseEvent):
|
|||||||
"""Base event for memory operations"""
|
"""Base event for memory operations"""
|
||||||
|
|
||||||
type: str
|
type: str
|
||||||
task_id: Optional[str] = None
|
task_id: str | None = None
|
||||||
task_name: Optional[str] = None
|
task_name: str | None = None
|
||||||
from_task: Optional[Any] = None
|
from_task: Any | None = None
|
||||||
from_agent: Optional[Any] = None
|
from_agent: Any | None = None
|
||||||
agent_role: Optional[str] = None
|
agent_role: str | None = None
|
||||||
agent_id: Optional[str] = None
|
agent_id: str | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
@@ -26,7 +26,7 @@ class MemoryQueryStartedEvent(MemoryBaseEvent):
|
|||||||
type: str = "memory_query_started"
|
type: str = "memory_query_started"
|
||||||
query: str
|
query: str
|
||||||
limit: int
|
limit: int
|
||||||
score_threshold: Optional[float] = None
|
score_threshold: float | None = None
|
||||||
|
|
||||||
|
|
||||||
class MemoryQueryCompletedEvent(MemoryBaseEvent):
|
class MemoryQueryCompletedEvent(MemoryBaseEvent):
|
||||||
@@ -36,7 +36,7 @@ class MemoryQueryCompletedEvent(MemoryBaseEvent):
|
|||||||
query: str
|
query: str
|
||||||
results: Any
|
results: Any
|
||||||
limit: int
|
limit: int
|
||||||
score_threshold: Optional[float] = None
|
score_threshold: float | None = None
|
||||||
query_time_ms: float
|
query_time_ms: float
|
||||||
|
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ class MemoryQueryFailedEvent(MemoryBaseEvent):
|
|||||||
type: str = "memory_query_failed"
|
type: str = "memory_query_failed"
|
||||||
query: str
|
query: str
|
||||||
limit: int
|
limit: int
|
||||||
score_threshold: Optional[float] = None
|
score_threshold: float | None = None
|
||||||
error: str
|
error: str
|
||||||
|
|
||||||
|
|
||||||
@@ -54,9 +54,9 @@ class MemorySaveStartedEvent(MemoryBaseEvent):
|
|||||||
"""Event emitted when a memory save operation is started"""
|
"""Event emitted when a memory save operation is started"""
|
||||||
|
|
||||||
type: str = "memory_save_started"
|
type: str = "memory_save_started"
|
||||||
value: Optional[str] = None
|
value: str | None = None
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: dict[str, Any] | None = None
|
||||||
agent_role: Optional[str] = None
|
agent_role: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class MemorySaveCompletedEvent(MemoryBaseEvent):
|
class MemorySaveCompletedEvent(MemoryBaseEvent):
|
||||||
@@ -64,8 +64,8 @@ class MemorySaveCompletedEvent(MemoryBaseEvent):
|
|||||||
|
|
||||||
type: str = "memory_save_completed"
|
type: str = "memory_save_completed"
|
||||||
value: str
|
value: str
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: dict[str, Any] | None = None
|
||||||
agent_role: Optional[str] = None
|
agent_role: str | None = None
|
||||||
save_time_ms: float
|
save_time_ms: float
|
||||||
|
|
||||||
|
|
||||||
@@ -73,9 +73,9 @@ class MemorySaveFailedEvent(MemoryBaseEvent):
|
|||||||
"""Event emitted when a memory save operation fails"""
|
"""Event emitted when a memory save operation fails"""
|
||||||
|
|
||||||
type: str = "memory_save_failed"
|
type: str = "memory_save_failed"
|
||||||
value: Optional[str] = None
|
value: str | None = None
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: dict[str, Any] | None = None
|
||||||
agent_role: Optional[str] = None
|
agent_role: str | None = None
|
||||||
error: str
|
error: str
|
||||||
|
|
||||||
|
|
||||||
@@ -83,13 +83,13 @@ class MemoryRetrievalStartedEvent(MemoryBaseEvent):
|
|||||||
"""Event emitted when memory retrieval for a task prompt starts"""
|
"""Event emitted when memory retrieval for a task prompt starts"""
|
||||||
|
|
||||||
type: str = "memory_retrieval_started"
|
type: str = "memory_retrieval_started"
|
||||||
task_id: Optional[str] = None
|
task_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
|
class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
|
||||||
"""Event emitted when memory retrieval for a task prompt completes successfully"""
|
"""Event emitted when memory retrieval for a task prompt completes successfully"""
|
||||||
|
|
||||||
type: str = "memory_retrieval_completed"
|
type: str = "memory_retrieval_completed"
|
||||||
task_id: Optional[str] = None
|
task_id: str | None = None
|
||||||
memory_content: str
|
memory_content: str
|
||||||
retrieval_time_ms: float
|
retrieval_time_ms: float
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from crewai.events.base_events import BaseEvent
|
from crewai.events.base_events import BaseEvent
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class ReasoningEvent(BaseEvent):
|
class ReasoningEvent(BaseEvent):
|
||||||
@@ -9,10 +10,10 @@ class ReasoningEvent(BaseEvent):
|
|||||||
attempt: int = 1
|
attempt: int = 1
|
||||||
agent_role: str
|
agent_role: str
|
||||||
task_id: str
|
task_id: str
|
||||||
task_name: Optional[str] = None
|
task_name: str | None = None
|
||||||
from_task: Optional[Any] = None
|
from_task: Any | None = None
|
||||||
agent_id: Optional[str] = None
|
agent_id: str | None = None
|
||||||
from_agent: Optional[Any] = None
|
from_agent: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from crewai.tasks.task_output import TaskOutput
|
|
||||||
from crewai.events.base_events import BaseEvent
|
from crewai.events.base_events import BaseEvent
|
||||||
|
from crewai.tasks.task_output import TaskOutput
|
||||||
|
|
||||||
|
|
||||||
class TaskStartedEvent(BaseEvent):
|
class TaskStartedEvent(BaseEvent):
|
||||||
"""Event emitted when a task starts"""
|
"""Event emitted when a task starts"""
|
||||||
|
|
||||||
type: str = "task_started"
|
type: str = "task_started"
|
||||||
context: Optional[str]
|
context: str | None
|
||||||
task: Optional[Any] = None
|
task: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
@@ -29,7 +29,7 @@ class TaskCompletedEvent(BaseEvent):
|
|||||||
|
|
||||||
output: TaskOutput
|
output: TaskOutput
|
||||||
type: str = "task_completed"
|
type: str = "task_completed"
|
||||||
task: Optional[Any] = None
|
task: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
@@ -49,7 +49,7 @@ class TaskFailedEvent(BaseEvent):
|
|||||||
|
|
||||||
error: str
|
error: str
|
||||||
type: str = "task_failed"
|
type: str = "task_failed"
|
||||||
task: Optional[Any] = None
|
task: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
@@ -69,7 +69,7 @@ class TaskEvaluationEvent(BaseEvent):
|
|||||||
|
|
||||||
type: str = "task_evaluation"
|
type: str = "task_evaluation"
|
||||||
evaluation_type: str
|
evaluation_type: str
|
||||||
task: Optional[Any] = None
|
task: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
from crewai.events.base_events import BaseEvent
|
from crewai.events.base_events import BaseEvent
|
||||||
|
|
||||||
@@ -7,21 +10,21 @@ from crewai.events.base_events import BaseEvent
|
|||||||
class ToolUsageEvent(BaseEvent):
|
class ToolUsageEvent(BaseEvent):
|
||||||
"""Base event for tool usage tracking"""
|
"""Base event for tool usage tracking"""
|
||||||
|
|
||||||
agent_key: Optional[str] = None
|
agent_key: str | None = None
|
||||||
agent_role: Optional[str] = None
|
agent_role: str | None = None
|
||||||
agent_id: Optional[str] = None
|
agent_id: str | None = None
|
||||||
tool_name: str
|
tool_name: str
|
||||||
tool_args: Dict[str, Any] | str
|
tool_args: dict[str, Any] | str
|
||||||
tool_class: Optional[str] = None
|
tool_class: str | None = None
|
||||||
run_attempts: int | None = None
|
run_attempts: int | None = None
|
||||||
delegations: int | None = None
|
delegations: int | None = None
|
||||||
agent: Optional[Any] = None
|
agent: Any | None = None
|
||||||
task_name: Optional[str] = None
|
task_name: str | None = None
|
||||||
task_id: Optional[str] = None
|
task_id: str | None = None
|
||||||
from_task: Optional[Any] = None
|
from_task: Any | None = None
|
||||||
from_agent: Optional[Any] = None
|
from_agent: Any | None = None
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
@@ -81,9 +84,9 @@ class ToolExecutionErrorEvent(BaseEvent):
|
|||||||
error: Any
|
error: Any
|
||||||
type: str = "tool_execution_error"
|
type: str = "tool_execution_error"
|
||||||
tool_name: str
|
tool_name: str
|
||||||
tool_args: Dict[str, Any]
|
tool_args: dict[str, Any]
|
||||||
tool_class: Callable
|
tool_class: Callable
|
||||||
agent: Optional[Any] = None
|
agent: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
|||||||
@@ -1,25 +1,25 @@
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
from rich.live import Live
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
from rich.syntax import Syntax
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
from rich.tree import Tree
|
from rich.tree import Tree
|
||||||
from rich.live import Live
|
|
||||||
from rich.syntax import Syntax
|
|
||||||
|
|
||||||
|
|
||||||
class ConsoleFormatter:
|
class ConsoleFormatter:
|
||||||
current_crew_tree: Optional[Tree] = None
|
current_crew_tree: ClassVar[Tree | None] = None
|
||||||
current_task_branch: Optional[Tree] = None
|
current_task_branch: ClassVar[Tree | None] = None
|
||||||
current_agent_branch: Optional[Tree] = None
|
current_agent_branch: ClassVar[Tree | None] = None
|
||||||
current_tool_branch: Optional[Tree] = None
|
current_tool_branch: ClassVar[Tree | None] = None
|
||||||
current_flow_tree: Optional[Tree] = None
|
current_flow_tree: ClassVar[Tree | None] = None
|
||||||
current_method_branch: Optional[Tree] = None
|
current_method_branch: ClassVar[Tree | None] = None
|
||||||
current_lite_agent_branch: Optional[Tree] = None
|
current_lite_agent_branch: ClassVar[Tree | None] = None
|
||||||
tool_usage_counts: Dict[str, int] = {}
|
tool_usage_counts: ClassVar[dict[str, int]] = {}
|
||||||
current_reasoning_branch: Optional[Tree] = None # Track reasoning status
|
current_reasoning_branch: ClassVar[Tree | None] = None # Track reasoning status
|
||||||
_live_paused: bool = False
|
_live_paused: ClassVar[bool] = False
|
||||||
current_llm_tool_tree: Optional[Tree] = None
|
current_llm_tool_tree: ClassVar[Tree | None] = None
|
||||||
|
|
||||||
def __init__(self, verbose: bool = False):
|
def __init__(self, verbose: bool = False):
|
||||||
self.console = Console(width=None)
|
self.console = Console(width=None)
|
||||||
@@ -29,7 +29,7 @@ class ConsoleFormatter:
|
|||||||
# instance so the previous render is replaced instead of writing a new one.
|
# instance so the previous render is replaced instead of writing a new one.
|
||||||
# Once any non-Tree renderable is printed we stop the Live session so the
|
# Once any non-Tree renderable is printed we stop the Live session so the
|
||||||
# final Tree persists on the terminal.
|
# final Tree persists on the terminal.
|
||||||
self._live: Optional[Live] = None
|
self._live: Live | None = None
|
||||||
|
|
||||||
def create_panel(self, content: Text, title: str, style: str = "blue") -> Panel:
|
def create_panel(self, content: Text, title: str, style: str = "blue") -> Panel:
|
||||||
"""Create a standardized panel with consistent styling."""
|
"""Create a standardized panel with consistent styling."""
|
||||||
@@ -45,7 +45,7 @@ class ConsoleFormatter:
|
|||||||
title: str,
|
title: str,
|
||||||
name: str,
|
name: str,
|
||||||
status_style: str = "blue",
|
status_style: str = "blue",
|
||||||
tool_args: Dict[str, Any] | str = "",
|
tool_args: dict[str, Any] | str = "",
|
||||||
**fields,
|
**fields,
|
||||||
) -> Text:
|
) -> Text:
|
||||||
"""Create standardized status content with consistent formatting."""
|
"""Create standardized status content with consistent formatting."""
|
||||||
@@ -70,7 +70,7 @@ class ConsoleFormatter:
|
|||||||
prefix: str,
|
prefix: str,
|
||||||
name: str,
|
name: str,
|
||||||
style: str = "blue",
|
style: str = "blue",
|
||||||
status: Optional[str] = None,
|
status: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update tree label with consistent formatting."""
|
"""Update tree label with consistent formatting."""
|
||||||
label = Text()
|
label = Text()
|
||||||
@@ -156,7 +156,7 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def update_crew_tree(
|
def update_crew_tree(
|
||||||
self,
|
self,
|
||||||
tree: Optional[Tree],
|
tree: Tree | None,
|
||||||
crew_name: str,
|
crew_name: str,
|
||||||
source_id: str,
|
source_id: str,
|
||||||
status: str = "completed",
|
status: str = "completed",
|
||||||
@@ -196,7 +196,7 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
self.print_panel(content, title, style)
|
self.print_panel(content, title, style)
|
||||||
|
|
||||||
def create_crew_tree(self, crew_name: str, source_id: str) -> Optional[Tree]:
|
def create_crew_tree(self, crew_name: str, source_id: str) -> Tree | None:
|
||||||
"""Create and initialize a new crew tree with initial status."""
|
"""Create and initialize a new crew tree with initial status."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return None
|
||||||
@@ -220,8 +220,8 @@ class ConsoleFormatter:
|
|||||||
return tree
|
return tree
|
||||||
|
|
||||||
def create_task_branch(
|
def create_task_branch(
|
||||||
self, crew_tree: Optional[Tree], task_id: str, task_name: Optional[str] = None
|
self, crew_tree: Tree | None, task_id: str, task_name: str | None = None
|
||||||
) -> Optional[Tree]:
|
) -> Tree | None:
|
||||||
"""Create and initialize a task branch."""
|
"""Create and initialize a task branch."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return None
|
||||||
@@ -255,11 +255,11 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def update_task_status(
|
def update_task_status(
|
||||||
self,
|
self,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
agent_role: str,
|
agent_role: str,
|
||||||
status: str = "completed",
|
status: str = "completed",
|
||||||
task_name: Optional[str] = None,
|
task_name: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update task status in the tree."""
|
"""Update task status in the tree."""
|
||||||
if not self.verbose or crew_tree is None:
|
if not self.verbose or crew_tree is None:
|
||||||
@@ -306,8 +306,8 @@ class ConsoleFormatter:
|
|||||||
self.print_panel(content, panel_title, style)
|
self.print_panel(content, panel_title, style)
|
||||||
|
|
||||||
def create_agent_branch(
|
def create_agent_branch(
|
||||||
self, task_branch: Optional[Tree], agent_role: str, crew_tree: Optional[Tree]
|
self, task_branch: Tree | None, agent_role: str, crew_tree: Tree | None
|
||||||
) -> Optional[Tree]:
|
) -> Tree | None:
|
||||||
"""Create and initialize an agent branch."""
|
"""Create and initialize an agent branch."""
|
||||||
if not self.verbose or not task_branch or not crew_tree:
|
if not self.verbose or not task_branch or not crew_tree:
|
||||||
return None
|
return None
|
||||||
@@ -325,9 +325,9 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def update_agent_status(
|
def update_agent_status(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
agent_role: str,
|
agent_role: str,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
status: str = "completed",
|
status: str = "completed",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update agent status in the tree."""
|
"""Update agent status in the tree."""
|
||||||
@@ -336,7 +336,7 @@ class ConsoleFormatter:
|
|||||||
# altering the tree. Keeping it a no-op avoids duplicate status lines.
|
# altering the tree. Keeping it a no-op avoids duplicate status lines.
|
||||||
return
|
return
|
||||||
|
|
||||||
def create_flow_tree(self, flow_name: str, flow_id: str) -> Optional[Tree]:
|
def create_flow_tree(self, flow_name: str, flow_id: str) -> Tree | None:
|
||||||
"""Create and initialize a flow tree."""
|
"""Create and initialize a flow tree."""
|
||||||
content = self.create_status_content(
|
content = self.create_status_content(
|
||||||
"Starting Flow Execution", flow_name, "blue", ID=flow_id
|
"Starting Flow Execution", flow_name, "blue", ID=flow_id
|
||||||
@@ -356,7 +356,7 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
return flow_tree
|
return flow_tree
|
||||||
|
|
||||||
def start_flow(self, flow_name: str, flow_id: str) -> Optional[Tree]:
|
def start_flow(self, flow_name: str, flow_id: str) -> Tree | None:
|
||||||
"""Initialize a flow execution tree."""
|
"""Initialize a flow execution tree."""
|
||||||
flow_tree = Tree("")
|
flow_tree = Tree("")
|
||||||
flow_label = Text()
|
flow_label = Text()
|
||||||
@@ -376,7 +376,7 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def update_flow_status(
|
def update_flow_status(
|
||||||
self,
|
self,
|
||||||
flow_tree: Optional[Tree],
|
flow_tree: Tree | None,
|
||||||
flow_name: str,
|
flow_name: str,
|
||||||
flow_id: str,
|
flow_id: str,
|
||||||
status: str = "completed",
|
status: str = "completed",
|
||||||
@@ -423,11 +423,11 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def update_method_status(
|
def update_method_status(
|
||||||
self,
|
self,
|
||||||
method_branch: Optional[Tree],
|
method_branch: Tree | None,
|
||||||
flow_tree: Optional[Tree],
|
flow_tree: Tree | None,
|
||||||
method_name: str,
|
method_name: str,
|
||||||
status: str = "running",
|
status: str = "running",
|
||||||
) -> Optional[Tree]:
|
) -> Tree | None:
|
||||||
"""Update method status in the flow tree."""
|
"""Update method status in the flow tree."""
|
||||||
if not flow_tree:
|
if not flow_tree:
|
||||||
return None
|
return None
|
||||||
@@ -480,7 +480,7 @@ class ConsoleFormatter:
|
|||||||
def handle_llm_tool_usage_started(
|
def handle_llm_tool_usage_started(
|
||||||
self,
|
self,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
tool_args: Dict[str, Any] | str,
|
tool_args: dict[str, Any] | str,
|
||||||
):
|
):
|
||||||
# Create status content for the tool usage
|
# Create status content for the tool usage
|
||||||
content = self.create_status_content(
|
content = self.create_status_content(
|
||||||
@@ -520,11 +520,11 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_tool_usage_started(
|
def handle_tool_usage_started(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
tool_args: Dict[str, Any] | str = "",
|
tool_args: dict[str, Any] | str = "",
|
||||||
) -> Optional[Tree]:
|
) -> Tree | None:
|
||||||
"""Handle tool usage started event."""
|
"""Handle tool usage started event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return None
|
||||||
@@ -569,9 +569,9 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_tool_usage_finished(
|
def handle_tool_usage_finished(
|
||||||
self,
|
self,
|
||||||
tool_branch: Optional[Tree],
|
tool_branch: Tree | None,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle tool usage finished event."""
|
"""Handle tool usage finished event."""
|
||||||
if not self.verbose or tool_branch is None:
|
if not self.verbose or tool_branch is None:
|
||||||
@@ -600,10 +600,10 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_tool_usage_error(
|
def handle_tool_usage_error(
|
||||||
self,
|
self,
|
||||||
tool_branch: Optional[Tree],
|
tool_branch: Tree | None,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
error: str,
|
error: str,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle tool usage error event."""
|
"""Handle tool usage error event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
@@ -631,9 +631,9 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_llm_call_started(
|
def handle_llm_call_started(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> Optional[Tree]:
|
) -> Tree | None:
|
||||||
"""Handle LLM call started event."""
|
"""Handle LLM call started event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return None
|
||||||
@@ -672,9 +672,9 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_llm_call_completed(
|
def handle_llm_call_completed(
|
||||||
self,
|
self,
|
||||||
tool_branch: Optional[Tree],
|
tool_branch: Tree | None,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle LLM call completed event."""
|
"""Handle LLM call completed event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
@@ -736,7 +736,7 @@ class ConsoleFormatter:
|
|||||||
self.print()
|
self.print()
|
||||||
|
|
||||||
def handle_llm_call_failed(
|
def handle_llm_call_failed(
|
||||||
self, tool_branch: Optional[Tree], error: str, crew_tree: Optional[Tree]
|
self, tool_branch: Tree | None, error: str, crew_tree: Tree | None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle LLM call failed event."""
|
"""Handle LLM call failed event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
@@ -789,7 +789,7 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_crew_test_started(
|
def handle_crew_test_started(
|
||||||
self, crew_name: str, source_id: str, n_iterations: int
|
self, crew_name: str, source_id: str, n_iterations: int
|
||||||
) -> Optional[Tree]:
|
) -> Tree | None:
|
||||||
"""Handle crew test started event."""
|
"""Handle crew test started event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return None
|
||||||
@@ -823,7 +823,7 @@ class ConsoleFormatter:
|
|||||||
return test_tree
|
return test_tree
|
||||||
|
|
||||||
def handle_crew_test_completed(
|
def handle_crew_test_completed(
|
||||||
self, flow_tree: Optional[Tree], crew_name: str
|
self, flow_tree: Tree | None, crew_name: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle crew test completed event."""
|
"""Handle crew test completed event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
@@ -913,7 +913,7 @@ class ConsoleFormatter:
|
|||||||
self.print_panel(failure_content, "Test Failure", "red")
|
self.print_panel(failure_content, "Test Failure", "red")
|
||||||
self.print()
|
self.print()
|
||||||
|
|
||||||
def create_lite_agent_branch(self, lite_agent_role: str) -> Optional[Tree]:
|
def create_lite_agent_branch(self, lite_agent_role: str) -> Tree | None:
|
||||||
"""Create and initialize a lite agent branch."""
|
"""Create and initialize a lite agent branch."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return None
|
||||||
@@ -935,10 +935,10 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def update_lite_agent_status(
|
def update_lite_agent_status(
|
||||||
self,
|
self,
|
||||||
lite_agent_branch: Optional[Tree],
|
lite_agent_branch: Tree | None,
|
||||||
lite_agent_role: str,
|
lite_agent_role: str,
|
||||||
status: str = "completed",
|
status: str = "completed",
|
||||||
**fields: Dict[str, Any],
|
**fields: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update lite agent status in the tree."""
|
"""Update lite agent status in the tree."""
|
||||||
if not self.verbose or lite_agent_branch is None:
|
if not self.verbose or lite_agent_branch is None:
|
||||||
@@ -981,7 +981,7 @@ class ConsoleFormatter:
|
|||||||
lite_agent_role: str,
|
lite_agent_role: str,
|
||||||
status: str = "started",
|
status: str = "started",
|
||||||
error: Any = None,
|
error: Any = None,
|
||||||
**fields: Dict[str, Any],
|
**fields: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle lite agent execution events with consistent formatting."""
|
"""Handle lite agent execution events with consistent formatting."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
@@ -1006,9 +1006,9 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_knowledge_retrieval_started(
|
def handle_knowledge_retrieval_started(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> Optional[Tree]:
|
) -> Tree | None:
|
||||||
"""Handle knowledge retrieval started event."""
|
"""Handle knowledge retrieval started event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return None
|
||||||
@@ -1034,13 +1034,13 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_knowledge_retrieval_completed(
|
def handle_knowledge_retrieval_completed(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
retrieved_knowledge: Any,
|
retrieved_knowledge: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle knowledge retrieval completed event."""
|
"""Handle knowledge retrieval completed event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return
|
||||||
|
|
||||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||||
tree_to_use = branch_to_use or crew_tree
|
tree_to_use = branch_to_use or crew_tree
|
||||||
@@ -1062,7 +1062,7 @@ class ConsoleFormatter:
|
|||||||
)
|
)
|
||||||
self.print(knowledge_panel)
|
self.print(knowledge_panel)
|
||||||
self.print()
|
self.print()
|
||||||
return None
|
return
|
||||||
|
|
||||||
knowledge_branch_found = False
|
knowledge_branch_found = False
|
||||||
for child in branch_to_use.children:
|
for child in branch_to_use.children:
|
||||||
@@ -1111,18 +1111,18 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_knowledge_query_started(
|
def handle_knowledge_query_started(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
task_prompt: str,
|
task_prompt: str,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle knowledge query generated event."""
|
"""Handle knowledge query generated event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return
|
||||||
|
|
||||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||||
tree_to_use = branch_to_use or crew_tree
|
tree_to_use = branch_to_use or crew_tree
|
||||||
if branch_to_use is None or tree_to_use is None:
|
if branch_to_use is None or tree_to_use is None:
|
||||||
return None
|
return
|
||||||
|
|
||||||
query_branch = branch_to_use.add("")
|
query_branch = branch_to_use.add("")
|
||||||
self.update_tree_label(
|
self.update_tree_label(
|
||||||
@@ -1134,9 +1134,9 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_knowledge_query_failed(
|
def handle_knowledge_query_failed(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
error: str,
|
error: str,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle knowledge query failed event."""
|
"""Handle knowledge query failed event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
@@ -1159,18 +1159,18 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_knowledge_query_completed(
|
def handle_knowledge_query_completed(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle knowledge query completed event."""
|
"""Handle knowledge query completed event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return
|
||||||
|
|
||||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||||
tree_to_use = branch_to_use or crew_tree
|
tree_to_use = branch_to_use or crew_tree
|
||||||
|
|
||||||
if branch_to_use is None or tree_to_use is None:
|
if branch_to_use is None or tree_to_use is None:
|
||||||
return None
|
return
|
||||||
|
|
||||||
query_branch = branch_to_use.add("")
|
query_branch = branch_to_use.add("")
|
||||||
self.update_tree_label(query_branch, "✅", "Knowledge Query Completed", "green")
|
self.update_tree_label(query_branch, "✅", "Knowledge Query Completed", "green")
|
||||||
@@ -1180,9 +1180,9 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_knowledge_search_query_failed(
|
def handle_knowledge_search_query_failed(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
error: str,
|
error: str,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle knowledge search query failed event."""
|
"""Handle knowledge search query failed event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
@@ -1207,10 +1207,10 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_reasoning_started(
|
def handle_reasoning_started(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
attempt: int,
|
attempt: int,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> Optional[Tree]:
|
) -> Tree | None:
|
||||||
"""Handle agent reasoning started (or refinement) event."""
|
"""Handle agent reasoning started (or refinement) event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return None
|
||||||
@@ -1249,7 +1249,7 @@ class ConsoleFormatter:
|
|||||||
self,
|
self,
|
||||||
plan: str,
|
plan: str,
|
||||||
ready: bool,
|
ready: bool,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle agent reasoning completed event."""
|
"""Handle agent reasoning completed event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
@@ -1292,7 +1292,7 @@ class ConsoleFormatter:
|
|||||||
def handle_reasoning_failed(
|
def handle_reasoning_failed(
|
||||||
self,
|
self,
|
||||||
error: str,
|
error: str,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle agent reasoning failure event."""
|
"""Handle agent reasoning failure event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
@@ -1329,7 +1329,7 @@ class ConsoleFormatter:
|
|||||||
def handle_agent_logs_started(
|
def handle_agent_logs_started(
|
||||||
self,
|
self,
|
||||||
agent_role: str,
|
agent_role: str,
|
||||||
task_description: Optional[str] = None,
|
task_description: str | None = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle agent logs started event."""
|
"""Handle agent logs started event."""
|
||||||
@@ -1367,10 +1367,11 @@ class ConsoleFormatter:
|
|||||||
if not verbose:
|
if not verbose:
|
||||||
return
|
return
|
||||||
|
|
||||||
from crewai.agents.parser import AgentAction, AgentFinish
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from crewai.agents.parser import AgentAction, AgentFinish
|
||||||
|
|
||||||
agent_role = agent_role.partition("\n")[0]
|
agent_role = agent_role.partition("\n")[0]
|
||||||
|
|
||||||
if isinstance(formatted_answer, AgentAction):
|
if isinstance(formatted_answer, AgentAction):
|
||||||
@@ -1473,9 +1474,9 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_memory_retrieval_started(
|
def handle_memory_retrieval_started(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> Optional[Tree]:
|
) -> Tree | None:
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -1497,13 +1498,13 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_memory_retrieval_completed(
|
def handle_memory_retrieval_completed(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
memory_content: str,
|
memory_content: str,
|
||||||
retrieval_time_ms: float,
|
retrieval_time_ms: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return
|
||||||
|
|
||||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||||
tree_to_use = branch_to_use or crew_tree
|
tree_to_use = branch_to_use or crew_tree
|
||||||
@@ -1528,7 +1529,7 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
if branch_to_use is None or tree_to_use is None:
|
if branch_to_use is None or tree_to_use is None:
|
||||||
add_panel()
|
add_panel()
|
||||||
return None
|
return
|
||||||
|
|
||||||
memory_branch_found = False
|
memory_branch_found = False
|
||||||
for child in branch_to_use.children:
|
for child in branch_to_use.children:
|
||||||
@@ -1565,13 +1566,13 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_memory_query_completed(
|
def handle_memory_query_completed(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
source_type: str,
|
source_type: str,
|
||||||
query_time_ms: float,
|
query_time_ms: float,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return
|
||||||
|
|
||||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||||
tree_to_use = branch_to_use or crew_tree
|
tree_to_use = branch_to_use or crew_tree
|
||||||
@@ -1580,7 +1581,7 @@ class ConsoleFormatter:
|
|||||||
branch_to_use = tree_to_use
|
branch_to_use = tree_to_use
|
||||||
|
|
||||||
if branch_to_use is None:
|
if branch_to_use is None:
|
||||||
return None
|
return
|
||||||
|
|
||||||
memory_type = source_type.replace("_", " ").title()
|
memory_type = source_type.replace("_", " ").title()
|
||||||
|
|
||||||
@@ -1598,13 +1599,13 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_memory_query_failed(
|
def handle_memory_query_failed(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
error: str,
|
error: str,
|
||||||
source_type: str,
|
source_type: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return
|
||||||
|
|
||||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||||
tree_to_use = branch_to_use or crew_tree
|
tree_to_use = branch_to_use or crew_tree
|
||||||
@@ -1613,7 +1614,7 @@ class ConsoleFormatter:
|
|||||||
branch_to_use = tree_to_use
|
branch_to_use = tree_to_use
|
||||||
|
|
||||||
if branch_to_use is None:
|
if branch_to_use is None:
|
||||||
return None
|
return
|
||||||
|
|
||||||
memory_type = source_type.replace("_", " ").title()
|
memory_type = source_type.replace("_", " ").title()
|
||||||
|
|
||||||
@@ -1630,16 +1631,16 @@ class ConsoleFormatter:
|
|||||||
break
|
break
|
||||||
|
|
||||||
def handle_memory_save_started(
|
def handle_memory_save_started(
|
||||||
self, agent_branch: Optional[Tree], crew_tree: Optional[Tree]
|
self, agent_branch: Tree | None, crew_tree: Tree | None
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return
|
||||||
|
|
||||||
branch_to_use = agent_branch or self.current_lite_agent_branch
|
branch_to_use = agent_branch or self.current_lite_agent_branch
|
||||||
tree_to_use = branch_to_use or crew_tree
|
tree_to_use = branch_to_use or crew_tree
|
||||||
|
|
||||||
if tree_to_use is None:
|
if tree_to_use is None:
|
||||||
return None
|
return
|
||||||
|
|
||||||
for child in tree_to_use.children:
|
for child in tree_to_use.children:
|
||||||
if "Memory Update" in str(child.label):
|
if "Memory Update" in str(child.label):
|
||||||
@@ -1655,19 +1656,19 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_memory_save_completed(
|
def handle_memory_save_completed(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
save_time_ms: float,
|
save_time_ms: float,
|
||||||
source_type: str,
|
source_type: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return
|
||||||
|
|
||||||
branch_to_use = agent_branch or self.current_lite_agent_branch
|
branch_to_use = agent_branch or self.current_lite_agent_branch
|
||||||
tree_to_use = branch_to_use or crew_tree
|
tree_to_use = branch_to_use or crew_tree
|
||||||
|
|
||||||
if tree_to_use is None:
|
if tree_to_use is None:
|
||||||
return None
|
return
|
||||||
|
|
||||||
memory_type = source_type.replace("_", " ").title()
|
memory_type = source_type.replace("_", " ").title()
|
||||||
content = f"✅ {memory_type} Memory Saved ({save_time_ms:.2f}ms)"
|
content = f"✅ {memory_type} Memory Saved ({save_time_ms:.2f}ms)"
|
||||||
@@ -1685,19 +1686,19 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
def handle_memory_save_failed(
|
def handle_memory_save_failed(
|
||||||
self,
|
self,
|
||||||
agent_branch: Optional[Tree],
|
agent_branch: Tree | None,
|
||||||
error: str,
|
error: str,
|
||||||
source_type: str,
|
source_type: str,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Tree | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
return None
|
return
|
||||||
|
|
||||||
branch_to_use = agent_branch or self.current_lite_agent_branch
|
branch_to_use = agent_branch or self.current_lite_agent_branch
|
||||||
tree_to_use = branch_to_use or crew_tree
|
tree_to_use = branch_to_use or crew_tree
|
||||||
|
|
||||||
if branch_to_use is None or tree_to_use is None:
|
if branch_to_use is None or tree_to_use is None:
|
||||||
return None
|
return
|
||||||
|
|
||||||
memory_type = source_type.replace("_", " ").title()
|
memory_type = source_type.replace("_", " ").title()
|
||||||
content = f"❌ {memory_type} Memory Save Failed"
|
content = f"❌ {memory_type} Memory Save Failed"
|
||||||
@@ -1738,7 +1739,7 @@ class ConsoleFormatter:
|
|||||||
def handle_guardrail_completed(
|
def handle_guardrail_completed(
|
||||||
self,
|
self,
|
||||||
success: bool,
|
success: bool,
|
||||||
error: Optional[str],
|
error: str | None,
|
||||||
retry_count: int,
|
retry_count: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Display guardrail evaluation result.
|
"""Display guardrail evaluation result.
|
||||||
|
|||||||
@@ -1,40 +1,39 @@
|
|||||||
from crewai.experimental.evaluation import (
|
from crewai.experimental.evaluation import (
|
||||||
|
AgentEvaluationResult,
|
||||||
|
AgentEvaluator,
|
||||||
BaseEvaluator,
|
BaseEvaluator,
|
||||||
EvaluationScore,
|
EvaluationScore,
|
||||||
MetricCategory,
|
|
||||||
AgentEvaluationResult,
|
|
||||||
SemanticQualityEvaluator,
|
|
||||||
GoalAlignmentEvaluator,
|
|
||||||
ReasoningEfficiencyEvaluator,
|
|
||||||
ToolSelectionEvaluator,
|
|
||||||
ParameterExtractionEvaluator,
|
|
||||||
ToolInvocationEvaluator,
|
|
||||||
EvaluationTraceCallback,
|
EvaluationTraceCallback,
|
||||||
create_evaluation_callbacks,
|
|
||||||
AgentEvaluator,
|
|
||||||
create_default_evaluator,
|
|
||||||
ExperimentRunner,
|
|
||||||
ExperimentResults,
|
|
||||||
ExperimentResult,
|
ExperimentResult,
|
||||||
|
ExperimentResults,
|
||||||
|
ExperimentRunner,
|
||||||
|
GoalAlignmentEvaluator,
|
||||||
|
MetricCategory,
|
||||||
|
ParameterExtractionEvaluator,
|
||||||
|
ReasoningEfficiencyEvaluator,
|
||||||
|
SemanticQualityEvaluator,
|
||||||
|
ToolInvocationEvaluator,
|
||||||
|
ToolSelectionEvaluator,
|
||||||
|
create_default_evaluator,
|
||||||
|
create_evaluation_callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AgentEvaluationResult",
|
||||||
|
"AgentEvaluator",
|
||||||
"BaseEvaluator",
|
"BaseEvaluator",
|
||||||
"EvaluationScore",
|
"EvaluationScore",
|
||||||
"MetricCategory",
|
|
||||||
"AgentEvaluationResult",
|
|
||||||
"SemanticQualityEvaluator",
|
|
||||||
"GoalAlignmentEvaluator",
|
|
||||||
"ReasoningEfficiencyEvaluator",
|
|
||||||
"ToolSelectionEvaluator",
|
|
||||||
"ParameterExtractionEvaluator",
|
|
||||||
"ToolInvocationEvaluator",
|
|
||||||
"EvaluationTraceCallback",
|
"EvaluationTraceCallback",
|
||||||
"create_evaluation_callbacks",
|
"ExperimentResult",
|
||||||
"AgentEvaluator",
|
|
||||||
"create_default_evaluator",
|
|
||||||
"ExperimentRunner",
|
|
||||||
"ExperimentResults",
|
"ExperimentResults",
|
||||||
"ExperimentResult"
|
"ExperimentRunner",
|
||||||
]
|
"GoalAlignmentEvaluator",
|
||||||
|
"MetricCategory",
|
||||||
|
"ParameterExtractionEvaluator",
|
||||||
|
"ReasoningEfficiencyEvaluator",
|
||||||
|
"SemanticQualityEvaluator",
|
||||||
|
"ToolInvocationEvaluator",
|
||||||
|
"ToolSelectionEvaluator",
|
||||||
|
"create_default_evaluator",
|
||||||
|
"create_evaluation_callbacks",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,51 +1,47 @@
|
|||||||
|
from crewai.experimental.evaluation.agent_evaluator import (
|
||||||
|
AgentEvaluator,
|
||||||
|
create_default_evaluator,
|
||||||
|
)
|
||||||
from crewai.experimental.evaluation.base_evaluator import (
|
from crewai.experimental.evaluation.base_evaluator import (
|
||||||
|
AgentEvaluationResult,
|
||||||
BaseEvaluator,
|
BaseEvaluator,
|
||||||
EvaluationScore,
|
EvaluationScore,
|
||||||
MetricCategory,
|
MetricCategory,
|
||||||
AgentEvaluationResult
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from crewai.experimental.evaluation.metrics import (
|
|
||||||
SemanticQualityEvaluator,
|
|
||||||
GoalAlignmentEvaluator,
|
|
||||||
ReasoningEfficiencyEvaluator,
|
|
||||||
ToolSelectionEvaluator,
|
|
||||||
ParameterExtractionEvaluator,
|
|
||||||
ToolInvocationEvaluator
|
|
||||||
)
|
|
||||||
|
|
||||||
from crewai.experimental.evaluation.evaluation_listener import (
|
from crewai.experimental.evaluation.evaluation_listener import (
|
||||||
EvaluationTraceCallback,
|
EvaluationTraceCallback,
|
||||||
create_evaluation_callbacks
|
create_evaluation_callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
from crewai.experimental.evaluation.agent_evaluator import (
|
|
||||||
AgentEvaluator,
|
|
||||||
create_default_evaluator
|
|
||||||
)
|
|
||||||
|
|
||||||
from crewai.experimental.evaluation.experiment import (
|
from crewai.experimental.evaluation.experiment import (
|
||||||
ExperimentRunner,
|
ExperimentResult,
|
||||||
ExperimentResults,
|
ExperimentResults,
|
||||||
ExperimentResult
|
ExperimentRunner,
|
||||||
|
)
|
||||||
|
from crewai.experimental.evaluation.metrics import (
|
||||||
|
GoalAlignmentEvaluator,
|
||||||
|
ParameterExtractionEvaluator,
|
||||||
|
ReasoningEfficiencyEvaluator,
|
||||||
|
SemanticQualityEvaluator,
|
||||||
|
ToolInvocationEvaluator,
|
||||||
|
ToolSelectionEvaluator,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AgentEvaluationResult",
|
||||||
|
"AgentEvaluator",
|
||||||
"BaseEvaluator",
|
"BaseEvaluator",
|
||||||
"EvaluationScore",
|
"EvaluationScore",
|
||||||
"MetricCategory",
|
|
||||||
"AgentEvaluationResult",
|
|
||||||
"SemanticQualityEvaluator",
|
|
||||||
"GoalAlignmentEvaluator",
|
|
||||||
"ReasoningEfficiencyEvaluator",
|
|
||||||
"ToolSelectionEvaluator",
|
|
||||||
"ParameterExtractionEvaluator",
|
|
||||||
"ToolInvocationEvaluator",
|
|
||||||
"EvaluationTraceCallback",
|
"EvaluationTraceCallback",
|
||||||
"create_evaluation_callbacks",
|
"ExperimentResult",
|
||||||
"AgentEvaluator",
|
|
||||||
"create_default_evaluator",
|
|
||||||
"ExperimentRunner",
|
|
||||||
"ExperimentResults",
|
"ExperimentResults",
|
||||||
"ExperimentResult"
|
"ExperimentRunner",
|
||||||
|
"GoalAlignmentEvaluator",
|
||||||
|
"MetricCategory",
|
||||||
|
"ParameterExtractionEvaluator",
|
||||||
|
"ReasoningEfficiencyEvaluator",
|
||||||
|
"SemanticQualityEvaluator",
|
||||||
|
"ToolInvocationEvaluator",
|
||||||
|
"ToolSelectionEvaluator",
|
||||||
|
"create_default_evaluator",
|
||||||
|
"create_evaluation_callbacks",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,34 +1,32 @@
|
|||||||
import threading
|
import threading
|
||||||
from typing import Any, Optional
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from crewai.experimental.evaluation.base_evaluator import (
|
|
||||||
AgentEvaluationResult,
|
|
||||||
AggregationStrategy,
|
|
||||||
)
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.task import Task
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
|
|
||||||
from crewai.events.types.agent_events import (
|
from crewai.events.types.agent_events import (
|
||||||
AgentEvaluationStartedEvent,
|
|
||||||
AgentEvaluationCompletedEvent,
|
AgentEvaluationCompletedEvent,
|
||||||
AgentEvaluationFailedEvent,
|
AgentEvaluationFailedEvent,
|
||||||
|
AgentEvaluationStartedEvent,
|
||||||
|
LiteAgentExecutionCompletedEvent,
|
||||||
)
|
)
|
||||||
from crewai.experimental.evaluation import BaseEvaluator, create_evaluation_callbacks
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
|
||||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
|
||||||
from crewai.events.types.task_events import TaskCompletedEvent
|
from crewai.events.types.task_events import TaskCompletedEvent
|
||||||
from crewai.events.types.agent_events import LiteAgentExecutionCompletedEvent
|
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||||
|
from crewai.experimental.evaluation import BaseEvaluator, create_evaluation_callbacks
|
||||||
from crewai.experimental.evaluation.base_evaluator import (
|
from crewai.experimental.evaluation.base_evaluator import (
|
||||||
AgentAggregatedEvaluationResult,
|
AgentAggregatedEvaluationResult,
|
||||||
|
AgentEvaluationResult,
|
||||||
|
AggregationStrategy,
|
||||||
EvaluationScore,
|
EvaluationScore,
|
||||||
MetricCategory,
|
MetricCategory,
|
||||||
)
|
)
|
||||||
|
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
|
||||||
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
class ExecutionState:
|
class ExecutionState:
|
||||||
current_agent_id: Optional[str] = None
|
current_agent_id: str | None = None
|
||||||
current_task_id: Optional[str] = None
|
current_task_id: str | None = None
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.traces = {}
|
self.traces = {}
|
||||||
@@ -284,7 +282,7 @@ class AgentEvaluator:
|
|||||||
error=str(e),
|
error=str(e),
|
||||||
)
|
)
|
||||||
self.console_formatter.print(
|
self.console_formatter.print(
|
||||||
f"Error in {evaluator.metric_category.value} evaluator: {str(e)}"
|
f"Error in {evaluator.metric_category.value} evaluator: {e!s}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -340,11 +338,11 @@ class AgentEvaluator:
|
|||||||
def create_default_evaluator(agents: list[Agent], llm: None = None):
|
def create_default_evaluator(agents: list[Agent], llm: None = None):
|
||||||
from crewai.experimental.evaluation import (
|
from crewai.experimental.evaluation import (
|
||||||
GoalAlignmentEvaluator,
|
GoalAlignmentEvaluator,
|
||||||
SemanticQualityEvaluator,
|
|
||||||
ToolSelectionEvaluator,
|
|
||||||
ParameterExtractionEvaluator,
|
ParameterExtractionEvaluator,
|
||||||
ToolInvocationEvaluator,
|
|
||||||
ReasoningEfficiencyEvaluator,
|
ReasoningEfficiencyEvaluator,
|
||||||
|
SemanticQualityEvaluator,
|
||||||
|
ToolInvocationEvaluator,
|
||||||
|
ToolSelectionEvaluator,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluators = [
|
evaluators = [
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
import abc
|
import abc
|
||||||
import enum
|
import enum
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.task import Task
|
|
||||||
from crewai.llm import BaseLLM
|
from crewai.llm import BaseLLM
|
||||||
|
from crewai.task import Task
|
||||||
from crewai.utilities.llm_utils import create_llm
|
from crewai.utilities.llm_utils import create_llm
|
||||||
|
|
||||||
|
|
||||||
class MetricCategory(enum.Enum):
|
class MetricCategory(enum.Enum):
|
||||||
GOAL_ALIGNMENT = "goal_alignment"
|
GOAL_ALIGNMENT = "goal_alignment"
|
||||||
SEMANTIC_QUALITY = "semantic_quality"
|
SEMANTIC_QUALITY = "semantic_quality"
|
||||||
@@ -19,7 +20,7 @@ class MetricCategory(enum.Enum):
|
|||||||
TOOL_INVOCATION = "tool_invocation"
|
TOOL_INVOCATION = "tool_invocation"
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
return self.value.replace('_', ' ').title()
|
return self.value.replace("_", " ").title()
|
||||||
|
|
||||||
|
|
||||||
class EvaluationScore(BaseModel):
|
class EvaluationScore(BaseModel):
|
||||||
@@ -27,15 +28,13 @@ class EvaluationScore(BaseModel):
|
|||||||
default=5.0,
|
default=5.0,
|
||||||
description="Numeric score from 0-10 where 0 is worst and 10 is best, None if not applicable",
|
description="Numeric score from 0-10 where 0 is worst and 10 is best, None if not applicable",
|
||||||
ge=0.0,
|
ge=0.0,
|
||||||
le=10.0
|
le=10.0,
|
||||||
)
|
)
|
||||||
feedback: str = Field(
|
feedback: str = Field(
|
||||||
default="",
|
default="", description="Detailed feedback explaining the evaluation score"
|
||||||
description="Detailed feedback explaining the evaluation score"
|
|
||||||
)
|
)
|
||||||
raw_response: str | None = Field(
|
raw_response: str | None = Field(
|
||||||
default=None,
|
default=None, description="Raw response from the evaluator (e.g., LLM)"
|
||||||
description="Raw response from the evaluator (e.g., LLM)"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
@@ -57,7 +56,7 @@ class BaseEvaluator(abc.ABC):
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
agent: Agent,
|
agent: Agent,
|
||||||
execution_trace: Dict[str, Any],
|
execution_trace: dict[str, Any],
|
||||||
final_output: Any,
|
final_output: Any,
|
||||||
task: Task | None = None,
|
task: Task | None = None,
|
||||||
) -> EvaluationScore:
|
) -> EvaluationScore:
|
||||||
@@ -67,9 +66,8 @@ class BaseEvaluator(abc.ABC):
|
|||||||
class AgentEvaluationResult(BaseModel):
|
class AgentEvaluationResult(BaseModel):
|
||||||
agent_id: str = Field(description="ID of the evaluated agent")
|
agent_id: str = Field(description="ID of the evaluated agent")
|
||||||
task_id: str = Field(description="ID of the task that was executed")
|
task_id: str = Field(description="ID of the task that was executed")
|
||||||
metrics: Dict[MetricCategory, EvaluationScore] = Field(
|
metrics: dict[MetricCategory, EvaluationScore] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict, description="Evaluation scores for each metric category"
|
||||||
description="Evaluation scores for each metric category"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -81,33 +79,23 @@ class AggregationStrategy(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class AgentAggregatedEvaluationResult(BaseModel):
|
class AgentAggregatedEvaluationResult(BaseModel):
|
||||||
agent_id: str = Field(
|
agent_id: str = Field(default="", description="ID of the agent")
|
||||||
default="",
|
agent_role: str = Field(default="", description="Role of the agent")
|
||||||
description="ID of the agent"
|
|
||||||
)
|
|
||||||
agent_role: str = Field(
|
|
||||||
default="",
|
|
||||||
description="Role of the agent"
|
|
||||||
)
|
|
||||||
task_count: int = Field(
|
task_count: int = Field(
|
||||||
default=0,
|
default=0, description="Number of tasks included in this aggregation"
|
||||||
description="Number of tasks included in this aggregation"
|
|
||||||
)
|
)
|
||||||
aggregation_strategy: AggregationStrategy = Field(
|
aggregation_strategy: AggregationStrategy = Field(
|
||||||
default=AggregationStrategy.SIMPLE_AVERAGE,
|
default=AggregationStrategy.SIMPLE_AVERAGE,
|
||||||
description="Strategy used for aggregation"
|
description="Strategy used for aggregation",
|
||||||
)
|
)
|
||||||
metrics: Dict[MetricCategory, EvaluationScore] = Field(
|
metrics: dict[MetricCategory, EvaluationScore] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict, description="Aggregated metrics across all tasks"
|
||||||
description="Aggregated metrics across all tasks"
|
|
||||||
)
|
)
|
||||||
task_results: List[str] = Field(
|
task_results: list[str] = Field(
|
||||||
default_factory=list,
|
default_factory=list, description="IDs of tasks included in this aggregation"
|
||||||
description="IDs of tasks included in this aggregation"
|
|
||||||
)
|
)
|
||||||
overall_score: Optional[float] = Field(
|
overall_score: float | None = Field(
|
||||||
default=None,
|
default=None, description="Overall score for this agent"
|
||||||
description="Overall score for this agent"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
@@ -119,7 +107,7 @@ class AgentAggregatedEvaluationResult(BaseModel):
|
|||||||
result += f"\n\n- {category.value.upper()}: {score.score}/10\n"
|
result += f"\n\n- {category.value.upper()}: {score.score}/10\n"
|
||||||
|
|
||||||
if score.feedback:
|
if score.feedback:
|
||||||
detailed_feedback = "\n ".join(score.feedback.split('\n'))
|
detailed_feedback = "\n ".join(score.feedback.split("\n"))
|
||||||
result += f" {detailed_feedback}\n"
|
result += f" {detailed_feedback}\n"
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, Any, List
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.box import HEAVY_EDGE, ROUNDED
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from rich.box import HEAVY_EDGE, ROUNDED
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||||
|
from crewai.experimental.evaluation import EvaluationScore
|
||||||
from crewai.experimental.evaluation.base_evaluator import (
|
from crewai.experimental.evaluation.base_evaluator import (
|
||||||
AgentAggregatedEvaluationResult,
|
AgentAggregatedEvaluationResult,
|
||||||
AggregationStrategy,
|
|
||||||
AgentEvaluationResult,
|
AgentEvaluationResult,
|
||||||
|
AggregationStrategy,
|
||||||
MetricCategory,
|
MetricCategory,
|
||||||
)
|
)
|
||||||
from crewai.experimental.evaluation import EvaluationScore
|
|
||||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
|
||||||
from crewai.utilities.llm_utils import create_llm
|
from crewai.utilities.llm_utils import create_llm
|
||||||
|
|
||||||
|
|
||||||
@@ -19,7 +21,7 @@ class EvaluationDisplayFormatter:
|
|||||||
self.console_formatter = ConsoleFormatter()
|
self.console_formatter = ConsoleFormatter()
|
||||||
|
|
||||||
def display_evaluation_with_feedback(
|
def display_evaluation_with_feedback(
|
||||||
self, iterations_results: Dict[int, Dict[str, List[Any]]]
|
self, iterations_results: dict[int, dict[str, list[Any]]]
|
||||||
):
|
):
|
||||||
if not iterations_results:
|
if not iterations_results:
|
||||||
self.console_formatter.print(
|
self.console_formatter.print(
|
||||||
@@ -99,7 +101,7 @@ class EvaluationDisplayFormatter:
|
|||||||
|
|
||||||
def display_summary_results(
|
def display_summary_results(
|
||||||
self,
|
self,
|
||||||
iterations_results: Dict[int, Dict[str, List[AgentAggregatedEvaluationResult]]],
|
iterations_results: dict[int, dict[str, list[AgentAggregatedEvaluationResult]]],
|
||||||
):
|
):
|
||||||
if not iterations_results:
|
if not iterations_results:
|
||||||
self.console_formatter.print(
|
self.console_formatter.print(
|
||||||
@@ -304,25 +306,25 @@ class EvaluationDisplayFormatter:
|
|||||||
self,
|
self,
|
||||||
agent_role: str,
|
agent_role: str,
|
||||||
metric: str,
|
metric: str,
|
||||||
feedbacks: List[str],
|
feedbacks: list[str],
|
||||||
scores: List[float | None],
|
scores: list[float | None],
|
||||||
strategy: AggregationStrategy,
|
strategy: AggregationStrategy,
|
||||||
) -> str:
|
) -> str:
|
||||||
if len(feedbacks) <= 2 and all(len(fb) < 200 for fb in feedbacks):
|
if len(feedbacks) <= 2 and all(len(fb) < 200 for fb in feedbacks):
|
||||||
return "\n\n".join(
|
return "\n\n".join(
|
||||||
[f"Feedback {i+1}: {fb}" for i, fb in enumerate(feedbacks)]
|
[f"Feedback {i + 1}: {fb}" for i, fb in enumerate(feedbacks)]
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
llm = create_llm()
|
llm = create_llm()
|
||||||
|
|
||||||
formatted_feedbacks = []
|
formatted_feedbacks = []
|
||||||
for i, (feedback, score) in enumerate(zip(feedbacks, scores)):
|
for i, (feedback, score) in enumerate(zip(feedbacks, scores, strict=False)):
|
||||||
if len(feedback) > 500:
|
if len(feedback) > 500:
|
||||||
feedback = feedback[:500] + "..."
|
feedback = feedback[:500] + "..."
|
||||||
score_text = f"{score:.1f}" if score is not None else "N/A"
|
score_text = f"{score:.1f}" if score is not None else "N/A"
|
||||||
formatted_feedbacks.append(
|
formatted_feedbacks.append(
|
||||||
f"Feedback #{i+1} (Score: {score_text}):\n{feedback}"
|
f"Feedback #{i + 1} (Score: {score_text}):\n{feedback}"
|
||||||
)
|
)
|
||||||
|
|
||||||
all_feedbacks = "\n\n" + "\n\n---\n\n".join(formatted_feedbacks)
|
all_feedbacks = "\n\n" + "\n\n---\n\n".join(formatted_feedbacks)
|
||||||
@@ -366,9 +368,7 @@ class EvaluationDisplayFormatter:
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
assert llm is not None
|
assert llm is not None
|
||||||
response = llm.call(prompt)
|
return llm.call(prompt)
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return "Synthesized from multiple tasks: " + "\n\n".join(
|
return "Synthesized from multiple tasks: " + "\n\n".join(
|
||||||
|
|||||||
@@ -1,26 +1,25 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.task import Task
|
|
||||||
from crewai.events.base_event_listener import BaseEventListener
|
from crewai.events.base_event_listener import BaseEventListener
|
||||||
from crewai.events.event_bus import CrewAIEventsBus
|
from crewai.events.event_bus import CrewAIEventsBus
|
||||||
from crewai.events.types.agent_events import (
|
from crewai.events.types.agent_events import (
|
||||||
AgentExecutionStartedEvent,
|
|
||||||
AgentExecutionCompletedEvent,
|
AgentExecutionCompletedEvent,
|
||||||
LiteAgentExecutionStartedEvent,
|
AgentExecutionStartedEvent,
|
||||||
LiteAgentExecutionCompletedEvent,
|
LiteAgentExecutionCompletedEvent,
|
||||||
|
LiteAgentExecutionStartedEvent,
|
||||||
)
|
)
|
||||||
|
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallStartedEvent
|
||||||
from crewai.events.types.tool_usage_events import (
|
from crewai.events.types.tool_usage_events import (
|
||||||
ToolUsageFinishedEvent,
|
|
||||||
ToolUsageErrorEvent,
|
|
||||||
ToolExecutionErrorEvent,
|
ToolExecutionErrorEvent,
|
||||||
ToolSelectionErrorEvent,
|
ToolSelectionErrorEvent,
|
||||||
|
ToolUsageErrorEvent,
|
||||||
|
ToolUsageFinishedEvent,
|
||||||
ToolValidateInputErrorEvent,
|
ToolValidateInputErrorEvent,
|
||||||
)
|
)
|
||||||
from crewai.events.types.llm_events import LLMCallStartedEvent, LLMCallCompletedEvent
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
class EvaluationTraceCallback(BaseEventListener):
|
class EvaluationTraceCallback(BaseEventListener):
|
||||||
@@ -253,7 +252,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
if hasattr(self, "current_llm_call"):
|
if hasattr(self, "current_llm_call"):
|
||||||
self.current_llm_call = {}
|
self.current_llm_call = {}
|
||||||
|
|
||||||
def get_trace(self, agent_id: str, task_id: str) -> Optional[Dict[str, Any]]:
|
def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
|
||||||
trace_key = f"{agent_id}_{task_id}"
|
trace_key = f"{agent_id}_{task_id}"
|
||||||
return self.traces.get(trace_key)
|
return self.traces.get(trace_key)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
|
from crewai.experimental.evaluation.experiment.result import (
|
||||||
|
ExperimentResult,
|
||||||
|
ExperimentResults,
|
||||||
|
)
|
||||||
from crewai.experimental.evaluation.experiment.runner import ExperimentRunner
|
from crewai.experimental.evaluation.experiment.runner import ExperimentRunner
|
||||||
from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["ExperimentResult", "ExperimentResults", "ExperimentRunner"]
|
||||||
"ExperimentRunner",
|
|
||||||
"ExperimentResults",
|
|
||||||
"ExperimentResult"
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ import json
|
|||||||
import os
|
import os
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class ExperimentResult(BaseModel):
|
class ExperimentResult(BaseModel):
|
||||||
identifier: str
|
identifier: str
|
||||||
inputs: dict[str, Any]
|
inputs: dict[str, Any]
|
||||||
@@ -12,35 +14,48 @@ class ExperimentResult(BaseModel):
|
|||||||
passed: bool
|
passed: bool
|
||||||
agent_evaluations: dict[str, Any] | None = None
|
agent_evaluations: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ExperimentResults:
|
class ExperimentResults:
|
||||||
def __init__(self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None):
|
def __init__(
|
||||||
|
self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None
|
||||||
|
):
|
||||||
self.results = results
|
self.results = results
|
||||||
self.metadata = metadata or {}
|
self.metadata = metadata or {}
|
||||||
self.timestamp = datetime.now(timezone.utc)
|
self.timestamp = datetime.now(timezone.utc)
|
||||||
|
|
||||||
from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
|
from crewai.experimental.evaluation.experiment.result_display import (
|
||||||
|
ExperimentResultsDisplay,
|
||||||
|
)
|
||||||
|
|
||||||
self.display = ExperimentResultsDisplay()
|
self.display = ExperimentResultsDisplay()
|
||||||
|
|
||||||
def to_json(self, filepath: str | None = None) -> dict[str, Any]:
|
def to_json(self, filepath: str | None = None) -> dict[str, Any]:
|
||||||
data = {
|
data = {
|
||||||
"timestamp": self.timestamp.isoformat(),
|
"timestamp": self.timestamp.isoformat(),
|
||||||
"metadata": self.metadata,
|
"metadata": self.metadata,
|
||||||
"results": [r.model_dump(exclude={"agent_evaluations"}) for r in self.results]
|
"results": [
|
||||||
|
r.model_dump(exclude={"agent_evaluations"}) for r in self.results
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
if filepath:
|
if filepath:
|
||||||
with open(filepath, 'w') as f:
|
with open(filepath, "w") as f:
|
||||||
json.dump(data, f, indent=2)
|
json.dump(data, f, indent=2)
|
||||||
self.display.console.print(f"[green]Results saved to {filepath}[/green]")
|
self.display.console.print(f"[green]Results saved to {filepath}[/green]")
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def compare_with_baseline(self, baseline_filepath: str, save_current: bool = True, print_summary: bool = False) -> dict[str, Any]:
|
def compare_with_baseline(
|
||||||
|
self,
|
||||||
|
baseline_filepath: str,
|
||||||
|
save_current: bool = True,
|
||||||
|
print_summary: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
baseline_runs = []
|
baseline_runs = []
|
||||||
|
|
||||||
if os.path.exists(baseline_filepath) and os.path.getsize(baseline_filepath) > 0:
|
if os.path.exists(baseline_filepath) and os.path.getsize(baseline_filepath) > 0:
|
||||||
try:
|
try:
|
||||||
with open(baseline_filepath, 'r') as f:
|
with open(baseline_filepath, "r") as f:
|
||||||
baseline_data = json.load(f)
|
baseline_data = json.load(f)
|
||||||
|
|
||||||
if isinstance(baseline_data, dict) and "timestamp" in baseline_data:
|
if isinstance(baseline_data, dict) and "timestamp" in baseline_data:
|
||||||
@@ -48,14 +63,18 @@ class ExperimentResults:
|
|||||||
elif isinstance(baseline_data, list):
|
elif isinstance(baseline_data, list):
|
||||||
baseline_runs = baseline_data
|
baseline_runs = baseline_data
|
||||||
except (json.JSONDecodeError, FileNotFoundError) as e:
|
except (json.JSONDecodeError, FileNotFoundError) as e:
|
||||||
self.display.console.print(f"[yellow]Warning: Could not load baseline file: {str(e)}[/yellow]")
|
self.display.console.print(
|
||||||
|
f"[yellow]Warning: Could not load baseline file: {e!s}[/yellow]"
|
||||||
|
)
|
||||||
|
|
||||||
if not baseline_runs:
|
if not baseline_runs:
|
||||||
if save_current:
|
if save_current:
|
||||||
current_data = self.to_json()
|
current_data = self.to_json()
|
||||||
with open(baseline_filepath, 'w') as f:
|
with open(baseline_filepath, "w") as f:
|
||||||
json.dump([current_data], f, indent=2)
|
json.dump([current_data], f, indent=2)
|
||||||
self.display.console.print(f"[green]Saved current results as new baseline to {baseline_filepath}[/green]")
|
self.display.console.print(
|
||||||
|
f"[green]Saved current results as new baseline to {baseline_filepath}[/green]"
|
||||||
|
)
|
||||||
return {"is_baseline": True, "changes": {}}
|
return {"is_baseline": True, "changes": {}}
|
||||||
|
|
||||||
baseline_runs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
|
baseline_runs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
|
||||||
@@ -69,9 +88,11 @@ class ExperimentResults:
|
|||||||
if save_current:
|
if save_current:
|
||||||
current_data = self.to_json()
|
current_data = self.to_json()
|
||||||
baseline_runs.append(current_data)
|
baseline_runs.append(current_data)
|
||||||
with open(baseline_filepath, 'w') as f:
|
with open(baseline_filepath, "w") as f:
|
||||||
json.dump(baseline_runs, f, indent=2)
|
json.dump(baseline_runs, f, indent=2)
|
||||||
self.display.console.print(f"[green]Added current results to baseline file {baseline_filepath}[/green]")
|
self.display.console.print(
|
||||||
|
f"[green]Added current results to baseline file {baseline_filepath}[/green]"
|
||||||
|
)
|
||||||
|
|
||||||
return comparison
|
return comparison
|
||||||
|
|
||||||
@@ -118,5 +139,5 @@ class ExperimentResults:
|
|||||||
"new_tests": new_tests,
|
"new_tests": new_tests,
|
||||||
"missing_tests": missing_tests,
|
"missing_tests": missing_tests,
|
||||||
"total_compared": len(improved) + len(regressed) + len(unchanged),
|
"total_compared": len(improved) + len(regressed) + len(unchanged),
|
||||||
"baseline_timestamp": baseline_run.get("timestamp", "unknown")
|
"baseline_timestamp": baseline_run.get("timestamp", "unknown"),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.table import Table
|
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
from crewai.experimental.evaluation.experiment.result import ExperimentResults
|
from crewai.experimental.evaluation.experiment.result import ExperimentResults
|
||||||
|
|
||||||
|
|
||||||
class ExperimentResultsDisplay:
|
class ExperimentResultsDisplay:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.console = Console()
|
self.console = Console()
|
||||||
@@ -19,13 +22,19 @@ class ExperimentResultsDisplay:
|
|||||||
table.add_row("Total Test Cases", str(total))
|
table.add_row("Total Test Cases", str(total))
|
||||||
table.add_row("Passed", str(passed))
|
table.add_row("Passed", str(passed))
|
||||||
table.add_row("Failed", str(total - passed))
|
table.add_row("Failed", str(total - passed))
|
||||||
table.add_row("Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A")
|
table.add_row(
|
||||||
|
"Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
|
||||||
|
)
|
||||||
|
|
||||||
self.console.print(table)
|
self.console.print(table)
|
||||||
|
|
||||||
def comparison_summary(self, comparison: Dict[str, Any], baseline_timestamp: str):
|
def comparison_summary(self, comparison: dict[str, Any], baseline_timestamp: str):
|
||||||
self.console.print(Panel(f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
|
self.console.print(
|
||||||
expand=False))
|
Panel(
|
||||||
|
f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
|
||||||
|
expand=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
table = Table(title="Results Comparison")
|
table = Table(title="Results Comparison")
|
||||||
table.add_column("Metric", style="cyan")
|
table.add_column("Metric", style="cyan")
|
||||||
@@ -34,7 +43,9 @@ class ExperimentResultsDisplay:
|
|||||||
|
|
||||||
improved = comparison.get("improved", [])
|
improved = comparison.get("improved", [])
|
||||||
if improved:
|
if improved:
|
||||||
details = ", ".join([f"{test_identifier}" for test_identifier in improved[:3]])
|
details = ", ".join(
|
||||||
|
[f"{test_identifier}" for test_identifier in improved[:3]]
|
||||||
|
)
|
||||||
if len(improved) > 3:
|
if len(improved) > 3:
|
||||||
details += f" and {len(improved) - 3} more"
|
details += f" and {len(improved) - 3} more"
|
||||||
table.add_row("✅ Improved", str(len(improved)), details)
|
table.add_row("✅ Improved", str(len(improved)), details)
|
||||||
@@ -43,7 +54,9 @@ class ExperimentResultsDisplay:
|
|||||||
|
|
||||||
regressed = comparison.get("regressed", [])
|
regressed = comparison.get("regressed", [])
|
||||||
if regressed:
|
if regressed:
|
||||||
details = ", ".join([f"{test_identifier}" for test_identifier in regressed[:3]])
|
details = ", ".join(
|
||||||
|
[f"{test_identifier}" for test_identifier in regressed[:3]]
|
||||||
|
)
|
||||||
if len(regressed) > 3:
|
if len(regressed) > 3:
|
||||||
details += f" and {len(regressed) - 3} more"
|
details += f" and {len(regressed) - 3} more"
|
||||||
table.add_row("❌ Regressed", str(len(regressed)), details, style="red")
|
table.add_row("❌ Regressed", str(len(regressed)), details, style="red")
|
||||||
|
|||||||
@@ -2,11 +2,19 @@ from collections import defaultdict
|
|||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from crewai import Crew, Agent
|
from crewai import Agent, Crew
|
||||||
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
|
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
|
||||||
from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
|
from crewai.experimental.evaluation.evaluation_display import (
|
||||||
from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult
|
AgentAggregatedEvaluationResult,
|
||||||
from crewai.experimental.evaluation.evaluation_display import AgentAggregatedEvaluationResult
|
)
|
||||||
|
from crewai.experimental.evaluation.experiment.result import (
|
||||||
|
ExperimentResult,
|
||||||
|
ExperimentResults,
|
||||||
|
)
|
||||||
|
from crewai.experimental.evaluation.experiment.result_display import (
|
||||||
|
ExperimentResultsDisplay,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExperimentRunner:
|
class ExperimentRunner:
|
||||||
def __init__(self, dataset: list[dict[str, Any]]):
|
def __init__(self, dataset: list[dict[str, Any]]):
|
||||||
@@ -14,7 +22,12 @@ class ExperimentRunner:
|
|||||||
self.evaluator: AgentEvaluator | None = None
|
self.evaluator: AgentEvaluator | None = None
|
||||||
self.display = ExperimentResultsDisplay()
|
self.display = ExperimentResultsDisplay()
|
||||||
|
|
||||||
def run(self, crew: Crew | None = None, agents: list[Agent] | None = None, print_summary: bool = False) -> ExperimentResults:
|
def run(
|
||||||
|
self,
|
||||||
|
crew: Crew | None = None,
|
||||||
|
agents: list[Agent] | None = None,
|
||||||
|
print_summary: bool = False,
|
||||||
|
) -> ExperimentResults:
|
||||||
if crew and not agents:
|
if crew and not agents:
|
||||||
agents = crew.agents
|
agents = crew.agents
|
||||||
|
|
||||||
@@ -35,13 +48,20 @@ class ExperimentRunner:
|
|||||||
|
|
||||||
return experiment_results
|
return experiment_results
|
||||||
|
|
||||||
def _run_test_case(self, test_case: dict[str, Any], agents: list[Agent], crew: Crew | None = None) -> ExperimentResult:
|
def _run_test_case(
|
||||||
|
self, test_case: dict[str, Any], agents: list[Agent], crew: Crew | None = None
|
||||||
|
) -> ExperimentResult:
|
||||||
inputs = test_case["inputs"]
|
inputs = test_case["inputs"]
|
||||||
expected_score = test_case["expected_score"]
|
expected_score = test_case["expected_score"]
|
||||||
identifier = test_case.get("identifier") or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
|
identifier = (
|
||||||
|
test_case.get("identifier")
|
||||||
|
or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.display.console.print(f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]")
|
self.display.console.print(
|
||||||
|
f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]"
|
||||||
|
)
|
||||||
self.display.console.print("\n")
|
self.display.console.print("\n")
|
||||||
if crew:
|
if crew:
|
||||||
crew.kickoff(inputs=inputs)
|
crew.kickoff(inputs=inputs)
|
||||||
@@ -61,35 +81,38 @@ class ExperimentRunner:
|
|||||||
score=actual_score,
|
score=actual_score,
|
||||||
expected_score=expected_score,
|
expected_score=expected_score,
|
||||||
passed=passed,
|
passed=passed,
|
||||||
agent_evaluations=agent_evaluations
|
agent_evaluations=agent_evaluations,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.display.console.print(f"[red]Error running test case: {str(e)}[/red]")
|
self.display.console.print(f"[red]Error running test case: {e!s}[/red]")
|
||||||
return ExperimentResult(
|
return ExperimentResult(
|
||||||
identifier=identifier,
|
identifier=identifier,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
score=0,
|
score=0,
|
||||||
expected_score=expected_score,
|
expected_score=expected_score,
|
||||||
passed=False
|
passed=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _extract_scores(self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]) -> float | dict[str, float]:
|
def _extract_scores(
|
||||||
|
self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]
|
||||||
|
) -> float | dict[str, float]:
|
||||||
all_scores: dict[str, list[float]] = defaultdict(list)
|
all_scores: dict[str, list[float]] = defaultdict(list)
|
||||||
for evaluation in agent_evaluations.values():
|
for evaluation in agent_evaluations.values():
|
||||||
for metric_name, score in evaluation.metrics.items():
|
for metric_name, score in evaluation.metrics.items():
|
||||||
if score.score is not None:
|
if score.score is not None:
|
||||||
all_scores[metric_name.value].append(score.score)
|
all_scores[metric_name.value].append(score.score)
|
||||||
|
|
||||||
avg_scores = {m: sum(s)/len(s) for m, s in all_scores.items()}
|
avg_scores = {m: sum(s) / len(s) for m, s in all_scores.items()}
|
||||||
|
|
||||||
if len(avg_scores) == 1:
|
if len(avg_scores) == 1:
|
||||||
return list(avg_scores.values())[0]
|
return next(iter(avg_scores.values()))
|
||||||
|
|
||||||
return avg_scores
|
return avg_scores
|
||||||
|
|
||||||
def _assert_scores(self, expected: float | dict[str, float],
|
def _assert_scores(
|
||||||
actual: float | dict[str, float]) -> bool:
|
self, expected: float | dict[str, float], actual: float | dict[str, float]
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Compare expected and actual scores, and return whether the test case passed.
|
Compare expected and actual scores, and return whether the test case passed.
|
||||||
|
|
||||||
@@ -122,4 +145,4 @@ class ExperimentRunner:
|
|||||||
# All matching keys must have actual >= expected
|
# All matching keys must have actual >= expected
|
||||||
return all(actual[key] >= expected[key] for key in matching_keys)
|
return all(actual[key] >= expected[key] for key in matching_keys)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1,26 +1,21 @@
|
|||||||
|
from crewai.experimental.evaluation.metrics.goal_metrics import GoalAlignmentEvaluator
|
||||||
from crewai.experimental.evaluation.metrics.reasoning_metrics import (
|
from crewai.experimental.evaluation.metrics.reasoning_metrics import (
|
||||||
ReasoningEfficiencyEvaluator
|
ReasoningEfficiencyEvaluator,
|
||||||
)
|
)
|
||||||
|
|
||||||
from crewai.experimental.evaluation.metrics.tools_metrics import (
|
|
||||||
ToolSelectionEvaluator,
|
|
||||||
ParameterExtractionEvaluator,
|
|
||||||
ToolInvocationEvaluator
|
|
||||||
)
|
|
||||||
|
|
||||||
from crewai.experimental.evaluation.metrics.goal_metrics import (
|
|
||||||
GoalAlignmentEvaluator
|
|
||||||
)
|
|
||||||
|
|
||||||
from crewai.experimental.evaluation.metrics.semantic_quality_metrics import (
|
from crewai.experimental.evaluation.metrics.semantic_quality_metrics import (
|
||||||
SemanticQualityEvaluator
|
SemanticQualityEvaluator,
|
||||||
|
)
|
||||||
|
from crewai.experimental.evaluation.metrics.tools_metrics import (
|
||||||
|
ParameterExtractionEvaluator,
|
||||||
|
ToolInvocationEvaluator,
|
||||||
|
ToolSelectionEvaluator,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ReasoningEfficiencyEvaluator",
|
|
||||||
"ToolSelectionEvaluator",
|
|
||||||
"ParameterExtractionEvaluator",
|
|
||||||
"ToolInvocationEvaluator",
|
|
||||||
"GoalAlignmentEvaluator",
|
"GoalAlignmentEvaluator",
|
||||||
"SemanticQualityEvaluator"
|
"ParameterExtractionEvaluator",
|
||||||
]
|
"ReasoningEfficiencyEvaluator",
|
||||||
|
"SemanticQualityEvaluator",
|
||||||
|
"ToolInvocationEvaluator",
|
||||||
|
"ToolSelectionEvaluator",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
|
from crewai.experimental.evaluation.base_evaluator import (
|
||||||
|
BaseEvaluator,
|
||||||
|
EvaluationScore,
|
||||||
|
MetricCategory,
|
||||||
|
)
|
||||||
|
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
|
|
||||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
|
||||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
|
||||||
|
|
||||||
class GoalAlignmentEvaluator(BaseEvaluator):
|
class GoalAlignmentEvaluator(BaseEvaluator):
|
||||||
@property
|
@property
|
||||||
@@ -14,7 +18,7 @@ class GoalAlignmentEvaluator(BaseEvaluator):
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
agent: Agent,
|
agent: Agent,
|
||||||
execution_trace: Dict[str, Any],
|
execution_trace: dict[str, Any],
|
||||||
final_output: Any,
|
final_output: Any,
|
||||||
task: Task | None = None,
|
task: Task | None = None,
|
||||||
) -> EvaluationScore:
|
) -> EvaluationScore:
|
||||||
@@ -23,7 +27,9 @@ class GoalAlignmentEvaluator(BaseEvaluator):
|
|||||||
task_context = f"Task description: {task.description}\nExpected output: {task.expected_output}\n"
|
task_context = f"Task description: {task.description}\nExpected output: {task.expected_output}\n"
|
||||||
|
|
||||||
prompt = [
|
prompt = [
|
||||||
{"role": "system", "content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.
|
||||||
|
|
||||||
Score the agent's goal alignment on a scale from 0-10 where:
|
Score the agent's goal alignment on a scale from 0-10 where:
|
||||||
- 0: Complete misalignment, agent did not understand or attempt the task goal
|
- 0: Complete misalignment, agent did not understand or attempt the task goal
|
||||||
@@ -37,8 +43,11 @@ Consider:
|
|||||||
4. Did the agent provide all requested information or deliverables?
|
4. Did the agent provide all requested information or deliverables?
|
||||||
|
|
||||||
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
|
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
|
||||||
"""},
|
""",
|
||||||
{"role": "user", "content": f"""
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"""
|
||||||
Agent role: {agent.role}
|
Agent role: {agent.role}
|
||||||
Agent goal: {agent.goal}
|
Agent goal: {agent.goal}
|
||||||
{task_context}
|
{task_context}
|
||||||
@@ -47,7 +56,8 @@ Agent's final output:
|
|||||||
{final_output}
|
{final_output}
|
||||||
|
|
||||||
Evaluate how well the agent's output aligns with the assigned task goal.
|
Evaluate how well the agent's output aligns with the assigned task goal.
|
||||||
"""}
|
""",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
assert self.llm is not None
|
assert self.llm is not None
|
||||||
response = self.llm.call(prompt)
|
response = self.llm.call(prompt)
|
||||||
@@ -59,11 +69,11 @@ Evaluate how well the agent's output aligns with the assigned task goal.
|
|||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=evaluation_data.get("score", 0),
|
score=evaluation_data.get("score", 0),
|
||||||
feedback=evaluation_data.get("feedback", response),
|
feedback=evaluation_data.get("feedback", response),
|
||||||
raw_response=response
|
raw_response=response,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=None,
|
score=None,
|
||||||
feedback=f"Failed to parse evaluation. Raw response: {response}",
|
feedback=f"Failed to parse evaluation. Raw response: {response}",
|
||||||
raw_response=response
|
raw_response=response,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,18 +8,23 @@ This module provides evaluator implementations for:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Tuple
|
|
||||||
import numpy as np
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.task import Task
|
from crewai.experimental.evaluation.base_evaluator import (
|
||||||
|
BaseEvaluator,
|
||||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
EvaluationScore,
|
||||||
|
MetricCategory,
|
||||||
|
)
|
||||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||||
|
from crewai.task import Task
|
||||||
from crewai.tasks.task_output import TaskOutput
|
from crewai.tasks.task_output import TaskOutput
|
||||||
|
|
||||||
|
|
||||||
class ReasoningPatternType(Enum):
|
class ReasoningPatternType(Enum):
|
||||||
EFFICIENT = "efficient" # Good reasoning flow
|
EFFICIENT = "efficient" # Good reasoning flow
|
||||||
LOOP = "loop" # Agent is stuck in a loop
|
LOOP = "loop" # Agent is stuck in a loop
|
||||||
@@ -36,7 +41,7 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
agent: Agent,
|
agent: Agent,
|
||||||
execution_trace: Dict[str, Any],
|
execution_trace: dict[str, Any],
|
||||||
final_output: TaskOutput | str,
|
final_output: TaskOutput | str,
|
||||||
task: Task | None = None,
|
task: Task | None = None,
|
||||||
) -> EvaluationScore:
|
) -> EvaluationScore:
|
||||||
@@ -49,7 +54,7 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
|||||||
if not llm_calls or len(llm_calls) < 2:
|
if not llm_calls or len(llm_calls) < 2:
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=None,
|
score=None,
|
||||||
feedback="Insufficient LLM calls to evaluate reasoning efficiency."
|
feedback="Insufficient LLM calls to evaluate reasoning efficiency.",
|
||||||
)
|
)
|
||||||
|
|
||||||
total_calls = len(llm_calls)
|
total_calls = len(llm_calls)
|
||||||
@@ -58,12 +63,16 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
|||||||
time_intervals = []
|
time_intervals = []
|
||||||
has_reliable_timing = True
|
has_reliable_timing = True
|
||||||
for i in range(1, len(llm_calls)):
|
for i in range(1, len(llm_calls)):
|
||||||
start_time = llm_calls[i-1].get("end_time")
|
start_time = llm_calls[i - 1].get("end_time")
|
||||||
end_time = llm_calls[i].get("start_time")
|
end_time = llm_calls[i].get("start_time")
|
||||||
if start_time and end_time and start_time != end_time:
|
if start_time and end_time and start_time != end_time:
|
||||||
try:
|
try:
|
||||||
interval = end_time - start_time
|
interval = end_time - start_time
|
||||||
time_intervals.append(interval.total_seconds() if hasattr(interval, 'total_seconds') else 0)
|
time_intervals.append(
|
||||||
|
interval.total_seconds()
|
||||||
|
if hasattr(interval, "total_seconds")
|
||||||
|
else 0
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
has_reliable_timing = False
|
has_reliable_timing = False
|
||||||
else:
|
else:
|
||||||
@@ -83,14 +92,22 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
|||||||
if has_reliable_timing and time_intervals:
|
if has_reliable_timing and time_intervals:
|
||||||
efficiency_metrics["avg_time_between_calls"] = np.mean(time_intervals)
|
efficiency_metrics["avg_time_between_calls"] = np.mean(time_intervals)
|
||||||
|
|
||||||
loop_info = f"Detected {len(loop_details)} potential reasoning loops." if loop_detected else "No significant reasoning loops detected."
|
loop_info = (
|
||||||
|
f"Detected {len(loop_details)} potential reasoning loops."
|
||||||
|
if loop_detected
|
||||||
|
else "No significant reasoning loops detected."
|
||||||
|
)
|
||||||
|
|
||||||
call_samples = self._get_call_samples(llm_calls)
|
call_samples = self._get_call_samples(llm_calls)
|
||||||
|
|
||||||
final_output = final_output.raw if isinstance(final_output, TaskOutput) else final_output
|
final_output = (
|
||||||
|
final_output.raw if isinstance(final_output, TaskOutput) else final_output
|
||||||
|
)
|
||||||
|
|
||||||
prompt = [
|
prompt = [
|
||||||
{"role": "system", "content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.
|
||||||
|
|
||||||
Evaluate the agent's reasoning efficiency across these five key subcategories:
|
Evaluate the agent's reasoning efficiency across these five key subcategories:
|
||||||
|
|
||||||
@@ -120,8 +137,11 @@ Return your evaluation as JSON with the following structure:
|
|||||||
"feedback": string (general feedback about overall reasoning efficiency),
|
"feedback": string (general feedback about overall reasoning efficiency),
|
||||||
"optimization_suggestions": string (concrete suggestions for improving reasoning efficiency),
|
"optimization_suggestions": string (concrete suggestions for improving reasoning efficiency),
|
||||||
"detected_patterns": string (describe any inefficient reasoning patterns you observe)
|
"detected_patterns": string (describe any inefficient reasoning patterns you observe)
|
||||||
}"""},
|
}""",
|
||||||
{"role": "user", "content": f"""
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"""
|
||||||
Agent role: {agent.role}
|
Agent role: {agent.role}
|
||||||
{task_context}
|
{task_context}
|
||||||
|
|
||||||
@@ -140,7 +160,8 @@ Agent's final output:
|
|||||||
|
|
||||||
Evaluate the reasoning efficiency of this agent based on these interaction patterns.
|
Evaluate the reasoning efficiency of this agent based on these interaction patterns.
|
||||||
Identify any inefficient reasoning patterns and provide specific suggestions for optimization.
|
Identify any inefficient reasoning patterns and provide specific suggestions for optimization.
|
||||||
"""}
|
""",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
assert self.llm is not None
|
assert self.llm is not None
|
||||||
@@ -156,34 +177,46 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
conciseness = scores.get("conciseness", 5.0)
|
conciseness = scores.get("conciseness", 5.0)
|
||||||
loop_avoidance = scores.get("loop_avoidance", 5.0)
|
loop_avoidance = scores.get("loop_avoidance", 5.0)
|
||||||
|
|
||||||
overall_score = evaluation_data.get("overall_score", evaluation_data.get("score", 5.0))
|
overall_score = evaluation_data.get(
|
||||||
|
"overall_score", evaluation_data.get("score", 5.0)
|
||||||
|
)
|
||||||
feedback = evaluation_data.get("feedback", "No detailed feedback provided.")
|
feedback = evaluation_data.get("feedback", "No detailed feedback provided.")
|
||||||
optimization_suggestions = evaluation_data.get("optimization_suggestions", "No specific suggestions provided.")
|
optimization_suggestions = evaluation_data.get(
|
||||||
|
"optimization_suggestions", "No specific suggestions provided."
|
||||||
|
)
|
||||||
|
|
||||||
detailed_feedback = "Reasoning Efficiency Evaluation:\n"
|
detailed_feedback = "Reasoning Efficiency Evaluation:\n"
|
||||||
detailed_feedback += f"• Focus: {focus}/10 - Staying on topic without tangents\n"
|
detailed_feedback += (
|
||||||
detailed_feedback += f"• Progression: {progression}/10 - Building on previous thinking\n"
|
f"• Focus: {focus}/10 - Staying on topic without tangents\n"
|
||||||
|
)
|
||||||
|
detailed_feedback += (
|
||||||
|
f"• Progression: {progression}/10 - Building on previous thinking\n"
|
||||||
|
)
|
||||||
detailed_feedback += f"• Decision Quality: {decision_quality}/10 - Making appropriate decisions\n"
|
detailed_feedback += f"• Decision Quality: {decision_quality}/10 - Making appropriate decisions\n"
|
||||||
detailed_feedback += f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
|
detailed_feedback += (
|
||||||
|
f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
|
||||||
|
)
|
||||||
detailed_feedback += f"• Loop Avoidance: {loop_avoidance}/10 - Avoiding repetitive patterns\n\n"
|
detailed_feedback += f"• Loop Avoidance: {loop_avoidance}/10 - Avoiding repetitive patterns\n\n"
|
||||||
|
|
||||||
detailed_feedback += f"Feedback:\n{feedback}\n\n"
|
detailed_feedback += f"Feedback:\n{feedback}\n\n"
|
||||||
detailed_feedback += f"Optimization Suggestions:\n{optimization_suggestions}"
|
detailed_feedback += (
|
||||||
|
f"Optimization Suggestions:\n{optimization_suggestions}"
|
||||||
|
)
|
||||||
|
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=float(overall_score),
|
score=float(overall_score),
|
||||||
feedback=detailed_feedback,
|
feedback=detailed_feedback,
|
||||||
raw_response=response
|
raw_response=response,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to parse reasoning efficiency evaluation: {e}")
|
logging.warning(f"Failed to parse reasoning efficiency evaluation: {e}")
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=None,
|
score=None,
|
||||||
feedback=f"Failed to parse reasoning efficiency evaluation. Raw response: {response[:200]}...",
|
feedback=f"Failed to parse reasoning efficiency evaluation. Raw response: {response[:200]}...",
|
||||||
raw_response=response
|
raw_response=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _detect_loops(self, llm_calls: List[Dict]) -> Tuple[bool, List[Dict]]:
|
def _detect_loops(self, llm_calls: list[dict]) -> tuple[bool, list[dict]]:
|
||||||
loop_details = []
|
loop_details = []
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
@@ -205,18 +238,20 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
# A more sophisticated approach would use semantic similarity
|
# A more sophisticated approach would use semantic similarity
|
||||||
similarity = self._calculate_text_similarity(messages[i], messages[j])
|
similarity = self._calculate_text_similarity(messages[i], messages[j])
|
||||||
if similarity > 0.7: # Arbitrary threshold
|
if similarity > 0.7: # Arbitrary threshold
|
||||||
loop_details.append({
|
loop_details.append(
|
||||||
"first_occurrence": i,
|
{
|
||||||
"second_occurrence": j,
|
"first_occurrence": i,
|
||||||
"similarity": similarity,
|
"second_occurrence": j,
|
||||||
"snippet": messages[i][:100] + "..."
|
"similarity": similarity,
|
||||||
})
|
"snippet": messages[i][:100] + "...",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return len(loop_details) > 0, loop_details
|
return len(loop_details) > 0, loop_details
|
||||||
|
|
||||||
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
|
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
|
||||||
text1 = re.sub(r'\s+', ' ', text1.lower()).strip()
|
text1 = re.sub(r"\s+", " ", text1.lower()).strip()
|
||||||
text2 = re.sub(r'\s+', ' ', text2.lower()).strip()
|
text2 = re.sub(r"\s+", " ", text2.lower()).strip()
|
||||||
|
|
||||||
# Simple Jaccard similarity on word sets
|
# Simple Jaccard similarity on word sets
|
||||||
words1 = set(text1.split())
|
words1 = set(text1.split())
|
||||||
@@ -227,7 +262,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
|
|
||||||
return intersection / union if union > 0 else 0.0
|
return intersection / union if union > 0 else 0.0
|
||||||
|
|
||||||
def _analyze_reasoning_patterns(self, llm_calls: List[Dict]) -> Dict[str, Any]:
|
def _analyze_reasoning_patterns(self, llm_calls: list[dict]) -> dict[str, Any]:
|
||||||
call_lengths = []
|
call_lengths = []
|
||||||
response_times = []
|
response_times = []
|
||||||
|
|
||||||
@@ -267,7 +302,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
details = "Agent is consistently verbose across interactions."
|
details = "Agent is consistently verbose across interactions."
|
||||||
elif len(llm_calls) > 10 and length_trend > 0.5:
|
elif len(llm_calls) > 10 and length_trend > 0.5:
|
||||||
primary_pattern = ReasoningPatternType.INDECISIVE
|
primary_pattern = ReasoningPatternType.INDECISIVE
|
||||||
details = "Agent shows signs of indecisiveness with increasing message lengths."
|
details = (
|
||||||
|
"Agent shows signs of indecisiveness with increasing message lengths."
|
||||||
|
)
|
||||||
elif std_length / avg_length > 0.8:
|
elif std_length / avg_length > 0.8:
|
||||||
primary_pattern = ReasoningPatternType.SCATTERED
|
primary_pattern = ReasoningPatternType.SCATTERED
|
||||||
details = "Agent shows inconsistent reasoning flow with highly variable responses."
|
details = "Agent shows inconsistent reasoning flow with highly variable responses."
|
||||||
@@ -279,8 +316,8 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
"avg_length": avg_length,
|
"avg_length": avg_length,
|
||||||
"std_length": std_length,
|
"std_length": std_length,
|
||||||
"length_trend": length_trend,
|
"length_trend": length_trend,
|
||||||
"loop_score": loop_score
|
"loop_score": loop_score,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def _calculate_trend(self, values: Sequence[float | int]) -> float:
|
def _calculate_trend(self, values: Sequence[float | int]) -> float:
|
||||||
@@ -303,7 +340,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
except Exception:
|
except Exception:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
def _calculate_loop_likelihood(self, call_lengths: Sequence[float], response_times: Sequence[float]) -> float:
|
def _calculate_loop_likelihood(
|
||||||
|
self, call_lengths: Sequence[float], response_times: Sequence[float]
|
||||||
|
) -> float:
|
||||||
if not call_lengths or len(call_lengths) < 3:
|
if not call_lengths or len(call_lengths) < 3:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
@@ -312,7 +351,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
if len(call_lengths) >= 4:
|
if len(call_lengths) >= 4:
|
||||||
repeated_lengths = 0
|
repeated_lengths = 0
|
||||||
for i in range(len(call_lengths) - 2):
|
for i in range(len(call_lengths) - 2):
|
||||||
ratio = call_lengths[i] / call_lengths[i + 2] if call_lengths[i + 2] > 0 else 0
|
ratio = (
|
||||||
|
call_lengths[i] / call_lengths[i + 2]
|
||||||
|
if call_lengths[i + 2] > 0
|
||||||
|
else 0
|
||||||
|
)
|
||||||
if 0.85 <= ratio <= 1.15:
|
if 0.85 <= ratio <= 1.15:
|
||||||
repeated_lengths += 1
|
repeated_lengths += 1
|
||||||
|
|
||||||
@@ -331,14 +374,20 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
|
|
||||||
return np.mean(indicators) if indicators else 0.0
|
return np.mean(indicators) if indicators else 0.0
|
||||||
|
|
||||||
def _get_call_samples(self, llm_calls: List[Dict]) -> str:
|
def _get_call_samples(self, llm_calls: list[dict]) -> str:
|
||||||
samples = []
|
samples = []
|
||||||
|
|
||||||
if len(llm_calls) <= 6:
|
if len(llm_calls) <= 6:
|
||||||
sample_indices = list(range(len(llm_calls)))
|
sample_indices = list(range(len(llm_calls)))
|
||||||
else:
|
else:
|
||||||
sample_indices = [0, 1, len(llm_calls) // 2 - 1, len(llm_calls) // 2,
|
sample_indices = [
|
||||||
len(llm_calls) - 2, len(llm_calls) - 1]
|
0,
|
||||||
|
1,
|
||||||
|
len(llm_calls) // 2 - 1,
|
||||||
|
len(llm_calls) // 2,
|
||||||
|
len(llm_calls) - 2,
|
||||||
|
len(llm_calls) - 1,
|
||||||
|
]
|
||||||
|
|
||||||
for idx in sample_indices:
|
for idx in sample_indices:
|
||||||
call = llm_calls[idx]
|
call = llm_calls[idx]
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
|
from crewai.experimental.evaluation.base_evaluator import (
|
||||||
|
BaseEvaluator,
|
||||||
|
EvaluationScore,
|
||||||
|
MetricCategory,
|
||||||
|
)
|
||||||
|
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
|
|
||||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
|
||||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
|
||||||
|
|
||||||
class SemanticQualityEvaluator(BaseEvaluator):
|
class SemanticQualityEvaluator(BaseEvaluator):
|
||||||
@property
|
@property
|
||||||
@@ -14,7 +18,7 @@ class SemanticQualityEvaluator(BaseEvaluator):
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
agent: Agent,
|
agent: Agent,
|
||||||
execution_trace: Dict[str, Any],
|
execution_trace: dict[str, Any],
|
||||||
final_output: Any,
|
final_output: Any,
|
||||||
task: Task | None = None,
|
task: Task | None = None,
|
||||||
) -> EvaluationScore:
|
) -> EvaluationScore:
|
||||||
@@ -22,7 +26,9 @@ class SemanticQualityEvaluator(BaseEvaluator):
|
|||||||
if task is not None:
|
if task is not None:
|
||||||
task_context = f"Task description: {task.description}"
|
task_context = f"Task description: {task.description}"
|
||||||
prompt = [
|
prompt = [
|
||||||
{"role": "system", "content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.
|
||||||
|
|
||||||
Score the semantic quality on a scale from 0-10 where:
|
Score the semantic quality on a scale from 0-10 where:
|
||||||
- 0: Completely incoherent, confusing, or logically flawed output
|
- 0: Completely incoherent, confusing, or logically flawed output
|
||||||
@@ -37,8 +43,11 @@ Consider:
|
|||||||
5. Is the output free from contradictions and logical fallacies?
|
5. Is the output free from contradictions and logical fallacies?
|
||||||
|
|
||||||
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
|
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
|
||||||
"""},
|
""",
|
||||||
{"role": "user", "content": f"""
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"""
|
||||||
Agent role: {agent.role}
|
Agent role: {agent.role}
|
||||||
{task_context}
|
{task_context}
|
||||||
|
|
||||||
@@ -46,7 +55,8 @@ Agent's final output:
|
|||||||
{final_output}
|
{final_output}
|
||||||
|
|
||||||
Evaluate the semantic quality and reasoning of this output.
|
Evaluate the semantic quality and reasoning of this output.
|
||||||
"""}
|
""",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
assert self.llm is not None
|
assert self.llm is not None
|
||||||
@@ -56,13 +66,15 @@ Evaluate the semantic quality and reasoning of this output.
|
|||||||
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
|
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
|
||||||
assert evaluation_data is not None
|
assert evaluation_data is not None
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=float(evaluation_data["score"]) if evaluation_data.get("score") is not None else None,
|
score=float(evaluation_data["score"])
|
||||||
|
if evaluation_data.get("score") is not None
|
||||||
|
else None,
|
||||||
feedback=evaluation_data.get("feedback", response),
|
feedback=evaluation_data.get("feedback", response),
|
||||||
raw_response=response
|
raw_response=response,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=None,
|
score=None,
|
||||||
feedback=f"Failed to parse evaluation. Raw response: {response}",
|
feedback=f"Failed to parse evaluation. Raw response: {response}",
|
||||||
raw_response=response
|
raw_response=response,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,14 +1,17 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
|
|
||||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
|
||||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
|
from crewai.experimental.evaluation.base_evaluator import (
|
||||||
|
BaseEvaluator,
|
||||||
|
EvaluationScore,
|
||||||
|
MetricCategory,
|
||||||
|
)
|
||||||
|
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
class ToolSelectionEvaluator(BaseEvaluator):
|
class ToolSelectionEvaluator(BaseEvaluator):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metric_category(self) -> MetricCategory:
|
def metric_category(self) -> MetricCategory:
|
||||||
return MetricCategory.TOOL_SELECTION
|
return MetricCategory.TOOL_SELECTION
|
||||||
@@ -16,7 +19,7 @@ class ToolSelectionEvaluator(BaseEvaluator):
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
agent: Agent,
|
agent: Agent,
|
||||||
execution_trace: Dict[str, Any],
|
execution_trace: dict[str, Any],
|
||||||
final_output: str,
|
final_output: str,
|
||||||
task: Task | None = None,
|
task: Task | None = None,
|
||||||
) -> EvaluationScore:
|
) -> EvaluationScore:
|
||||||
@@ -26,19 +29,18 @@ class ToolSelectionEvaluator(BaseEvaluator):
|
|||||||
|
|
||||||
tool_uses = execution_trace.get("tool_uses", [])
|
tool_uses = execution_trace.get("tool_uses", [])
|
||||||
tool_count = len(tool_uses)
|
tool_count = len(tool_uses)
|
||||||
unique_tool_types = set([tool.get("tool", "Unknown tool") for tool in tool_uses])
|
unique_tool_types = set(
|
||||||
|
[tool.get("tool", "Unknown tool") for tool in tool_uses]
|
||||||
|
)
|
||||||
|
|
||||||
if tool_count == 0:
|
if tool_count == 0:
|
||||||
if not agent.tools:
|
if not agent.tools:
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=None,
|
score=None, feedback="Agent had no tools available to use."
|
||||||
feedback="Agent had no tools available to use."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return EvaluationScore(
|
|
||||||
score=None,
|
|
||||||
feedback="Agent had tools available but didn't use any."
|
|
||||||
)
|
)
|
||||||
|
return EvaluationScore(
|
||||||
|
score=None, feedback="Agent had tools available but didn't use any."
|
||||||
|
)
|
||||||
|
|
||||||
available_tools_info = ""
|
available_tools_info = ""
|
||||||
if agent.tools:
|
if agent.tools:
|
||||||
@@ -52,7 +54,9 @@ class ToolSelectionEvaluator(BaseEvaluator):
|
|||||||
tool_types_summary += f"- {tool_type}\n"
|
tool_types_summary += f"- {tool_type}\n"
|
||||||
|
|
||||||
prompt = [
|
prompt = [
|
||||||
{"role": "system", "content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.
|
||||||
|
|
||||||
You must evaluate based on these 2 criteria:
|
You must evaluate based on these 2 criteria:
|
||||||
1. Relevance (0-10): Were the tools chosen directly aligned with the task's goals?
|
1. Relevance (0-10): Were the tools chosen directly aligned with the task's goals?
|
||||||
@@ -73,8 +77,11 @@ Return your evaluation as JSON with these fields:
|
|||||||
- overall_score: number (average of all scores, 0-10)
|
- overall_score: number (average of all scores, 0-10)
|
||||||
- feedback: string (focused ONLY on tool selection decisions from available tools)
|
- feedback: string (focused ONLY on tool selection decisions from available tools)
|
||||||
- improvement_suggestions: string (ONLY suggest better selection from the AVAILABLE tools list, NOT new tools)
|
- improvement_suggestions: string (ONLY suggest better selection from the AVAILABLE tools list, NOT new tools)
|
||||||
"""},
|
""",
|
||||||
{"role": "user", "content": f"""
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"""
|
||||||
Agent role: {agent.role}
|
Agent role: {agent.role}
|
||||||
{task_context}
|
{task_context}
|
||||||
|
|
||||||
@@ -89,7 +96,8 @@ IMPORTANT:
|
|||||||
- ONLY evaluate selection from tools listed as available
|
- ONLY evaluate selection from tools listed as available
|
||||||
- DO NOT suggest new tools that aren't in the available tools list
|
- DO NOT suggest new tools that aren't in the available tools list
|
||||||
- DO NOT evaluate tool usage or results
|
- DO NOT evaluate tool usage or results
|
||||||
"""}
|
""",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
assert self.llm is not None
|
assert self.llm is not None
|
||||||
response = self.llm.call(prompt)
|
response = self.llm.call(prompt)
|
||||||
@@ -105,22 +113,24 @@ IMPORTANT:
|
|||||||
|
|
||||||
feedback = "Tool Selection Evaluation:\n"
|
feedback = "Tool Selection Evaluation:\n"
|
||||||
feedback += f"• Relevance: {relevance}/10 - Selection of appropriate tool types for the task\n"
|
feedback += f"• Relevance: {relevance}/10 - Selection of appropriate tool types for the task\n"
|
||||||
feedback += f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
|
feedback += (
|
||||||
|
f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
|
||||||
|
)
|
||||||
if "improvement_suggestions" in evaluation_data:
|
if "improvement_suggestions" in evaluation_data:
|
||||||
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
|
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
|
||||||
else:
|
else:
|
||||||
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
|
feedback += evaluation_data.get(
|
||||||
|
"feedback", "No detailed feedback available."
|
||||||
|
)
|
||||||
|
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=overall_score,
|
score=overall_score, feedback=feedback, raw_response=response
|
||||||
feedback=feedback,
|
|
||||||
raw_response=response
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=None,
|
score=None,
|
||||||
feedback=f"Error evaluating tool selection: {e}",
|
feedback=f"Error evaluating tool selection: {e}",
|
||||||
raw_response=response
|
raw_response=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -132,7 +142,7 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
agent: Agent,
|
agent: Agent,
|
||||||
execution_trace: Dict[str, Any],
|
execution_trace: dict[str, Any],
|
||||||
final_output: str,
|
final_output: str,
|
||||||
task: Task | None = None,
|
task: Task | None = None,
|
||||||
) -> EvaluationScore:
|
) -> EvaluationScore:
|
||||||
@@ -145,19 +155,26 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
|||||||
if tool_count == 0:
|
if tool_count == 0:
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=None,
|
score=None,
|
||||||
feedback="No tool usage detected. Cannot evaluate parameter extraction."
|
feedback="No tool usage detected. Cannot evaluate parameter extraction.",
|
||||||
)
|
)
|
||||||
|
|
||||||
validation_errors = []
|
validation_errors = []
|
||||||
for tool_use in tool_uses:
|
for tool_use in tool_uses:
|
||||||
if not tool_use.get("success", True) and tool_use.get("error_type") == "validation_error":
|
if (
|
||||||
validation_errors.append({
|
not tool_use.get("success", True)
|
||||||
"tool": tool_use.get("tool", "Unknown tool"),
|
and tool_use.get("error_type") == "validation_error"
|
||||||
"error": tool_use.get("result"),
|
):
|
||||||
"args": tool_use.get("args", {})
|
validation_errors.append(
|
||||||
})
|
{
|
||||||
|
"tool": tool_use.get("tool", "Unknown tool"),
|
||||||
|
"error": tool_use.get("result"),
|
||||||
|
"args": tool_use.get("args", {}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
validation_error_rate = len(validation_errors) / tool_count if tool_count > 0 else 0
|
validation_error_rate = (
|
||||||
|
len(validation_errors) / tool_count if tool_count > 0 else 0
|
||||||
|
)
|
||||||
|
|
||||||
param_samples = []
|
param_samples = []
|
||||||
for i, tool_use in enumerate(tool_uses[:5]):
|
for i, tool_use in enumerate(tool_uses[:5]):
|
||||||
@@ -168,7 +185,7 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
|||||||
|
|
||||||
is_validation_error = error_type == "validation_error"
|
is_validation_error = error_type == "validation_error"
|
||||||
|
|
||||||
sample = f"Tool use #{i+1} - {tool_name}:\n"
|
sample = f"Tool use #{i + 1} - {tool_name}:\n"
|
||||||
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
|
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
|
||||||
sample += f"- Success: {'No' if not success else 'Yes'}"
|
sample += f"- Success: {'No' if not success else 'Yes'}"
|
||||||
|
|
||||||
@@ -187,13 +204,17 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
|||||||
tool_name = err.get("tool", "Unknown tool")
|
tool_name = err.get("tool", "Unknown tool")
|
||||||
error_msg = err.get("error", "Unknown error")
|
error_msg = err.get("error", "Unknown error")
|
||||||
args = err.get("args", {})
|
args = err.get("args", {})
|
||||||
validation_errors_info += f"\nValidation Error #{i+1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"
|
validation_errors_info += f"\nValidation Error #{i + 1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"
|
||||||
|
|
||||||
if len(validation_errors) > 3:
|
if len(validation_errors) > 3:
|
||||||
validation_errors_info += f"\n...and {len(validation_errors) - 3} more validation errors."
|
validation_errors_info += (
|
||||||
|
f"\n...and {len(validation_errors) - 3} more validation errors."
|
||||||
|
)
|
||||||
param_samples_text = "\n\n".join(param_samples)
|
param_samples_text = "\n\n".join(param_samples)
|
||||||
prompt = [
|
prompt = [
|
||||||
{"role": "system", "content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.
|
||||||
|
|
||||||
Your job is to evaluate ONLY whether the agent used the correct parameter VALUES, not whether the right tools were selected or how the tools were invoked.
|
Your job is to evaluate ONLY whether the agent used the correct parameter VALUES, not whether the right tools were selected or how the tools were invoked.
|
||||||
|
|
||||||
@@ -216,8 +237,11 @@ Return your evaluation as JSON with these fields:
|
|||||||
- overall_score: number (average of all scores, 0-10)
|
- overall_score: number (average of all scores, 0-10)
|
||||||
- feedback: string (focused ONLY on parameter value extraction quality)
|
- feedback: string (focused ONLY on parameter value extraction quality)
|
||||||
- improvement_suggestions: string (concrete suggestions for better parameter VALUE extraction)
|
- improvement_suggestions: string (concrete suggestions for better parameter VALUE extraction)
|
||||||
"""},
|
""",
|
||||||
{"role": "user", "content": f"""
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"""
|
||||||
Agent role: {agent.role}
|
Agent role: {agent.role}
|
||||||
{task_context}
|
{task_context}
|
||||||
|
|
||||||
@@ -226,7 +250,8 @@ Parameter extraction examples:
|
|||||||
{validation_errors_info}
|
{validation_errors_info}
|
||||||
|
|
||||||
Evaluate the quality of the agent's parameter extraction for this task.
|
Evaluate the quality of the agent's parameter extraction for this task.
|
||||||
"""}
|
""",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
assert self.llm is not None
|
assert self.llm is not None
|
||||||
@@ -251,18 +276,18 @@ Evaluate the quality of the agent's parameter extraction for this task.
|
|||||||
if "improvement_suggestions" in evaluation_data:
|
if "improvement_suggestions" in evaluation_data:
|
||||||
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
|
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
|
||||||
else:
|
else:
|
||||||
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
|
feedback += evaluation_data.get(
|
||||||
|
"feedback", "No detailed feedback available."
|
||||||
|
)
|
||||||
|
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=overall_score,
|
score=overall_score, feedback=feedback, raw_response=response
|
||||||
feedback=feedback,
|
|
||||||
raw_response=response
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=None,
|
score=None,
|
||||||
feedback=f"Error evaluating parameter extraction: {e}",
|
feedback=f"Error evaluating parameter extraction: {e}",
|
||||||
raw_response=response
|
raw_response=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -274,7 +299,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
agent: Agent,
|
agent: Agent,
|
||||||
execution_trace: Dict[str, Any],
|
execution_trace: dict[str, Any],
|
||||||
final_output: str,
|
final_output: str,
|
||||||
task: Task | None = None,
|
task: Task | None = None,
|
||||||
) -> EvaluationScore:
|
) -> EvaluationScore:
|
||||||
@@ -288,7 +313,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
|
|||||||
if tool_count == 0:
|
if tool_count == 0:
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=None,
|
score=None,
|
||||||
feedback="No tool usage detected. Cannot evaluate tool invocation."
|
feedback="No tool usage detected. Cannot evaluate tool invocation.",
|
||||||
)
|
)
|
||||||
|
|
||||||
for tool_use in tool_uses:
|
for tool_use in tool_uses:
|
||||||
@@ -296,7 +321,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
|
|||||||
error_info = {
|
error_info = {
|
||||||
"tool": tool_use.get("tool", "Unknown tool"),
|
"tool": tool_use.get("tool", "Unknown tool"),
|
||||||
"error": tool_use.get("result"),
|
"error": tool_use.get("result"),
|
||||||
"error_type": tool_use.get("error_type", "unknown_error")
|
"error_type": tool_use.get("error_type", "unknown_error"),
|
||||||
}
|
}
|
||||||
tool_errors.append(error_info)
|
tool_errors.append(error_info)
|
||||||
|
|
||||||
@@ -315,9 +340,11 @@ class ToolInvocationEvaluator(BaseEvaluator):
|
|||||||
tool_args = tool_use.get("args", {})
|
tool_args = tool_use.get("args", {})
|
||||||
success = tool_use.get("success", True) and not tool_use.get("error", False)
|
success = tool_use.get("success", True) and not tool_use.get("error", False)
|
||||||
error_type = tool_use.get("error_type", "") if not success else ""
|
error_type = tool_use.get("error_type", "") if not success else ""
|
||||||
error_msg = tool_use.get("result", "No error") if not success else "No error"
|
error_msg = (
|
||||||
|
tool_use.get("result", "No error") if not success else "No error"
|
||||||
|
)
|
||||||
|
|
||||||
sample = f"Tool invocation #{i+1}:\n"
|
sample = f"Tool invocation #{i + 1}:\n"
|
||||||
sample += f"- Tool: {tool_name}\n"
|
sample += f"- Tool: {tool_name}\n"
|
||||||
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
|
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
|
||||||
sample += f"- Success: {'No' if not success else 'Yes'}\n"
|
sample += f"- Success: {'No' if not success else 'Yes'}\n"
|
||||||
@@ -330,11 +357,13 @@ class ToolInvocationEvaluator(BaseEvaluator):
|
|||||||
if error_types:
|
if error_types:
|
||||||
error_type_summary = "Error type breakdown:\n"
|
error_type_summary = "Error type breakdown:\n"
|
||||||
for error_type, count in error_types.items():
|
for error_type, count in error_types.items():
|
||||||
error_type_summary += f"- {error_type}: {count} occurrences ({(count/tool_count):.1%})\n"
|
error_type_summary += f"- {error_type}: {count} occurrences ({(count / tool_count):.1%})\n"
|
||||||
|
|
||||||
invocation_samples_text = "\n\n".join(invocation_samples)
|
invocation_samples_text = "\n\n".join(invocation_samples)
|
||||||
prompt = [
|
prompt = [
|
||||||
{"role": "system", "content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.
|
||||||
|
|
||||||
Your job is to evaluate ONLY the structural and syntactical aspects of how the agent called tools, NOT which tools were selected or what parameter values were used.
|
Your job is to evaluate ONLY the structural and syntactical aspects of how the agent called tools, NOT which tools were selected or what parameter values were used.
|
||||||
|
|
||||||
@@ -359,8 +388,11 @@ Return your evaluation as JSON with these fields:
|
|||||||
- overall_score: number (average of all scores, 0-10)
|
- overall_score: number (average of all scores, 0-10)
|
||||||
- feedback: string (focused ONLY on structural aspects of tool invocation)
|
- feedback: string (focused ONLY on structural aspects of tool invocation)
|
||||||
- improvement_suggestions: string (concrete suggestions for better structuring of tool calls)
|
- improvement_suggestions: string (concrete suggestions for better structuring of tool calls)
|
||||||
"""},
|
""",
|
||||||
{"role": "user", "content": f"""
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"""
|
||||||
Agent role: {agent.role}
|
Agent role: {agent.role}
|
||||||
{task_context}
|
{task_context}
|
||||||
|
|
||||||
@@ -371,7 +403,8 @@ Tool error rate: {error_rate:.2%} ({len(tool_errors)} errors out of {tool_count}
|
|||||||
{error_type_summary}
|
{error_type_summary}
|
||||||
|
|
||||||
Evaluate the quality of the agent's tool invocation structure during this task.
|
Evaluate the quality of the agent's tool invocation structure during this task.
|
||||||
"""}
|
""",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
assert self.llm is not None
|
assert self.llm is not None
|
||||||
@@ -388,23 +421,25 @@ Evaluate the quality of the agent's tool invocation structure during this task.
|
|||||||
overall_score = float(evaluation_data.get("overall_score", 5.0))
|
overall_score = float(evaluation_data.get("overall_score", 5.0))
|
||||||
|
|
||||||
feedback = "Tool Invocation Evaluation:\n"
|
feedback = "Tool Invocation Evaluation:\n"
|
||||||
feedback += f"• Structure: {structure}/10 - Following proper syntax and format\n"
|
feedback += (
|
||||||
|
f"• Structure: {structure}/10 - Following proper syntax and format\n"
|
||||||
|
)
|
||||||
feedback += f"• Error Handling: {error_handling}/10 - Appropriately handling tool errors\n"
|
feedback += f"• Error Handling: {error_handling}/10 - Appropriately handling tool errors\n"
|
||||||
feedback += f"• Invocation Patterns: {invocation_patterns}/10 - Proper sequencing and management of calls\n\n"
|
feedback += f"• Invocation Patterns: {invocation_patterns}/10 - Proper sequencing and management of calls\n\n"
|
||||||
|
|
||||||
if "improvement_suggestions" in evaluation_data:
|
if "improvement_suggestions" in evaluation_data:
|
||||||
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
|
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
|
||||||
else:
|
else:
|
||||||
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
|
feedback += evaluation_data.get(
|
||||||
|
"feedback", "No detailed feedback available."
|
||||||
|
)
|
||||||
|
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=overall_score,
|
score=overall_score, feedback=feedback, raw_response=response
|
||||||
feedback=feedback,
|
|
||||||
raw_response=response
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return EvaluationScore(
|
return EvaluationScore(
|
||||||
score=None,
|
score=None,
|
||||||
feedback=f"Error evaluating tool invocation: {e}",
|
feedback=f"Error evaluating tool invocation: {e}",
|
||||||
raw_response=response
|
raw_response=response,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,12 +1,21 @@
|
|||||||
import inspect
|
import inspect
|
||||||
|
import warnings
|
||||||
|
|
||||||
from typing_extensions import Any
|
from typing_extensions import Any
|
||||||
import warnings
|
|
||||||
from crewai.experimental.evaluation.experiment import ExperimentResults, ExperimentRunner
|
|
||||||
from crewai import Crew, Agent
|
|
||||||
|
|
||||||
def assert_experiment_successfully(experiment_results: ExperimentResults, baseline_filepath: str | None = None) -> None:
|
from crewai import Agent, Crew
|
||||||
failed_tests = [result for result in experiment_results.results if not result.passed]
|
from crewai.experimental.evaluation.experiment import (
|
||||||
|
ExperimentResults,
|
||||||
|
ExperimentRunner,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_experiment_successfully(
|
||||||
|
experiment_results: ExperimentResults, baseline_filepath: str | None = None
|
||||||
|
) -> None:
|
||||||
|
failed_tests = [
|
||||||
|
result for result in experiment_results.results if not result.passed
|
||||||
|
]
|
||||||
|
|
||||||
if failed_tests:
|
if failed_tests:
|
||||||
detailed_failures: list[str] = []
|
detailed_failures: list[str] = []
|
||||||
@@ -14,39 +23,54 @@ def assert_experiment_successfully(experiment_results: ExperimentResults, baseli
|
|||||||
for result in failed_tests:
|
for result in failed_tests:
|
||||||
expected = result.expected_score
|
expected = result.expected_score
|
||||||
actual = result.score
|
actual = result.score
|
||||||
detailed_failures.append(f"- {result.identifier}: expected {expected}, got {actual}")
|
detailed_failures.append(
|
||||||
|
f"- {result.identifier}: expected {expected}, got {actual}"
|
||||||
|
)
|
||||||
|
|
||||||
failure_details = "\n".join(detailed_failures)
|
failure_details = "\n".join(detailed_failures)
|
||||||
raise AssertionError(f"The following test cases failed:\n{failure_details}")
|
raise AssertionError(f"The following test cases failed:\n{failure_details}")
|
||||||
|
|
||||||
baseline_filepath = baseline_filepath or _get_baseline_filepath_fallback()
|
baseline_filepath = baseline_filepath or _get_baseline_filepath_fallback()
|
||||||
comparison = experiment_results.compare_with_baseline(baseline_filepath=baseline_filepath)
|
comparison = experiment_results.compare_with_baseline(
|
||||||
|
baseline_filepath=baseline_filepath
|
||||||
|
)
|
||||||
assert_experiment_no_regression(comparison)
|
assert_experiment_no_regression(comparison)
|
||||||
|
|
||||||
|
|
||||||
def assert_experiment_no_regression(comparison_result: dict[str, list[str]]) -> None:
|
def assert_experiment_no_regression(comparison_result: dict[str, list[str]]) -> None:
|
||||||
regressed = comparison_result.get("regressed", [])
|
regressed = comparison_result.get("regressed", [])
|
||||||
if regressed:
|
if regressed:
|
||||||
raise AssertionError(f"Regression detected! The following tests that previously passed now fail: {regressed}")
|
raise AssertionError(
|
||||||
|
f"Regression detected! The following tests that previously passed now fail: {regressed}"
|
||||||
|
)
|
||||||
|
|
||||||
missing_tests = comparison_result.get("missing_tests", [])
|
missing_tests = comparison_result.get("missing_tests", [])
|
||||||
if missing_tests:
|
if missing_tests:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Warning: {len(missing_tests)} tests from the baseline are missing in the current run: {missing_tests}",
|
f"Warning: {len(missing_tests)} tests from the baseline are missing in the current run: {missing_tests}",
|
||||||
UserWarning
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_experiment(dataset: list[dict[str, Any]], crew: Crew | None = None, agents: list[Agent] | None = None, verbose: bool = False) -> ExperimentResults:
|
|
||||||
|
def run_experiment(
|
||||||
|
dataset: list[dict[str, Any]],
|
||||||
|
crew: Crew | None = None,
|
||||||
|
agents: list[Agent] | None = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> ExperimentResults:
|
||||||
runner = ExperimentRunner(dataset=dataset)
|
runner = ExperimentRunner(dataset=dataset)
|
||||||
|
|
||||||
return runner.run(agents=agents, crew=crew, print_summary=verbose)
|
return runner.run(agents=agents, crew=crew, print_summary=verbose)
|
||||||
|
|
||||||
|
|
||||||
def _get_baseline_filepath_fallback() -> str:
|
def _get_baseline_filepath_fallback() -> str:
|
||||||
test_func_name = "experiment_fallback"
|
test_func_name = "experiment_fallback"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current_frame = inspect.currentframe()
|
current_frame = inspect.currentframe()
|
||||||
if current_frame is not None:
|
if current_frame is not None:
|
||||||
test_func_name = current_frame.f_back.f_back.f_code.co_name # type: ignore[union-attr]
|
test_func_name = current_frame.f_back.f_back.f_code.co_name # type: ignore[union-attr]
|
||||||
except Exception:
|
except Exception:
|
||||||
...
|
...
|
||||||
return f"{test_func_name}_results.json"
|
return f"{test_func_name}_results.json"
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from crewai.flow.flow import Flow, start, listen, or_, and_, router
|
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||||
from crewai.flow.persistence import persist
|
from crewai.flow.persistence import persist
|
||||||
|
|
||||||
__all__ = ["Flow", "start", "listen", "or_", "and_", "router", "persist"]
|
__all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"]
|
||||||
|
|
||||||
|
|||||||
@@ -1086,7 +1086,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
for method_name in self._start_methods:
|
for method_name in self._start_methods:
|
||||||
# Check if this start method is triggered by the current trigger
|
# Check if this start method is triggered by the current trigger
|
||||||
if method_name in self._listeners:
|
if method_name in self._listeners:
|
||||||
condition_type, trigger_methods = self._listeners[
|
_condition_type, trigger_methods = self._listeners[
|
||||||
method_name
|
method_name
|
||||||
]
|
]
|
||||||
if current_trigger in trigger_methods:
|
if current_trigger in trigger_methods:
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, InstanceOf, model_validator
|
from pydantic import BaseModel, Field, InstanceOf, model_validator
|
||||||
|
|
||||||
@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
|
|||||||
inspecting the call stack.
|
inspecting the call stack.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
parent_flow: Optional[InstanceOf[Flow]] = Field(
|
parent_flow: InstanceOf[Flow] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The parent flow of the instance, if it was created inside a flow.",
|
description="The parent flow of the instance, if it was created inside a flow.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
# flow_visualizer.py
|
# flow_visualizer.py
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from pyvis.network import Network
|
from pyvis.network import Network
|
||||||
|
|
||||||
from crewai.flow.config import COLORS, NODE_STYLES
|
from crewai.flow.config import COLORS, NODE_STYLES
|
||||||
from crewai.flow.html_template_handler import HTMLTemplateHandler
|
from crewai.flow.html_template_handler import HTMLTemplateHandler
|
||||||
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
|
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
|
||||||
from crewai.flow.path_utils import safe_path_join, validate_path_exists
|
from crewai.flow.path_utils import safe_path_join
|
||||||
from crewai.flow.utils import calculate_node_levels
|
from crewai.flow.utils import calculate_node_levels
|
||||||
from crewai.flow.visualization_utils import (
|
from crewai.flow.visualization_utils import (
|
||||||
add_edges,
|
add_edges,
|
||||||
@@ -34,13 +33,13 @@ class FlowPlot:
|
|||||||
ValueError
|
ValueError
|
||||||
If flow object is invalid or missing required attributes.
|
If flow object is invalid or missing required attributes.
|
||||||
"""
|
"""
|
||||||
if not hasattr(flow, '_methods'):
|
if not hasattr(flow, "_methods"):
|
||||||
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
||||||
if not hasattr(flow, '_listeners'):
|
if not hasattr(flow, "_listeners"):
|
||||||
raise ValueError("Invalid flow object: missing '_listeners' attribute")
|
raise ValueError("Invalid flow object: missing '_listeners' attribute")
|
||||||
if not hasattr(flow, '_start_methods'):
|
if not hasattr(flow, "_start_methods"):
|
||||||
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
|
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
|
||||||
|
|
||||||
self.flow = flow
|
self.flow = flow
|
||||||
self.colors = COLORS
|
self.colors = COLORS
|
||||||
self.node_styles = NODE_STYLES
|
self.node_styles = NODE_STYLES
|
||||||
@@ -65,7 +64,7 @@ class FlowPlot:
|
|||||||
"""
|
"""
|
||||||
if not filename or not isinstance(filename, str):
|
if not filename or not isinstance(filename, str):
|
||||||
raise ValueError("Filename must be a non-empty string")
|
raise ValueError("Filename must be a non-empty string")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize network
|
# Initialize network
|
||||||
net = Network(
|
net = Network(
|
||||||
@@ -96,32 +95,32 @@ class FlowPlot:
|
|||||||
try:
|
try:
|
||||||
node_levels = calculate_node_levels(self.flow)
|
node_levels = calculate_node_levels(self.flow)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to calculate node levels: {str(e)}")
|
raise ValueError(f"Failed to calculate node levels: {e!s}")
|
||||||
|
|
||||||
# Compute positions
|
# Compute positions
|
||||||
try:
|
try:
|
||||||
node_positions = compute_positions(self.flow, node_levels)
|
node_positions = compute_positions(self.flow, node_levels)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to compute node positions: {str(e)}")
|
raise ValueError(f"Failed to compute node positions: {e!s}")
|
||||||
|
|
||||||
# Add nodes to the network
|
# Add nodes to the network
|
||||||
try:
|
try:
|
||||||
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
|
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
|
raise RuntimeError(f"Failed to add nodes to network: {e!s}")
|
||||||
|
|
||||||
# Add edges to the network
|
# Add edges to the network
|
||||||
try:
|
try:
|
||||||
add_edges(net, self.flow, node_positions, self.colors)
|
add_edges(net, self.flow, node_positions, self.colors)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to add edges to network: {str(e)}")
|
raise RuntimeError(f"Failed to add edges to network: {e!s}")
|
||||||
|
|
||||||
# Generate HTML
|
# Generate HTML
|
||||||
try:
|
try:
|
||||||
network_html = net.generate_html()
|
network_html = net.generate_html()
|
||||||
final_html_content = self._generate_final_html(network_html)
|
final_html_content = self._generate_final_html(network_html)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
|
raise RuntimeError(f"Failed to generate network visualization: {e!s}")
|
||||||
|
|
||||||
# Save the final HTML content to the file
|
# Save the final HTML content to the file
|
||||||
try:
|
try:
|
||||||
@@ -129,12 +128,14 @@ class FlowPlot:
|
|||||||
f.write(final_html_content)
|
f.write(final_html_content)
|
||||||
print(f"Plot saved as {filename}.html")
|
print(f"Plot saved as {filename}.html")
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
|
raise IOError(
|
||||||
|
f"Failed to save flow visualization to {filename}.html: {e!s}"
|
||||||
|
)
|
||||||
|
|
||||||
except (ValueError, RuntimeError, IOError) as e:
|
except (ValueError, RuntimeError, IOError) as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
|
raise RuntimeError(f"Unexpected error during flow visualization: {e!s}")
|
||||||
finally:
|
finally:
|
||||||
self._cleanup_pyvis_lib()
|
self._cleanup_pyvis_lib()
|
||||||
|
|
||||||
@@ -165,7 +166,9 @@ class FlowPlot:
|
|||||||
try:
|
try:
|
||||||
# Extract just the body content from the generated HTML
|
# Extract just the body content from the generated HTML
|
||||||
current_dir = os.path.dirname(__file__)
|
current_dir = os.path.dirname(__file__)
|
||||||
template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir)
|
template_path = safe_path_join(
|
||||||
|
"assets", "crewai_flow_visual_template.html", root=current_dir
|
||||||
|
)
|
||||||
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
|
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
|
||||||
|
|
||||||
if not os.path.exists(template_path):
|
if not os.path.exists(template_path):
|
||||||
@@ -179,12 +182,9 @@ class FlowPlot:
|
|||||||
# Generate the legend items HTML
|
# Generate the legend items HTML
|
||||||
legend_items = get_legend_items(self.colors)
|
legend_items = get_legend_items(self.colors)
|
||||||
legend_items_html = generate_legend_items_html(legend_items)
|
legend_items_html = generate_legend_items_html(legend_items)
|
||||||
final_html_content = html_handler.generate_final_html(
|
return html_handler.generate_final_html(network_body, legend_items_html)
|
||||||
network_body, legend_items_html
|
|
||||||
)
|
|
||||||
return final_html_content
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise IOError(f"Failed to generate visualization HTML: {str(e)}")
|
raise IOError(f"Failed to generate visualization HTML: {e!s}")
|
||||||
|
|
||||||
def _cleanup_pyvis_lib(self):
|
def _cleanup_pyvis_lib(self):
|
||||||
"""
|
"""
|
||||||
@@ -197,6 +197,7 @@ class FlowPlot:
|
|||||||
lib_folder = safe_path_join("lib", root=os.getcwd())
|
lib_folder = safe_path_join("lib", root=os.getcwd())
|
||||||
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
shutil.rmtree(lib_folder)
|
shutil.rmtree(lib_folder)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(f"Error validating lib folder path: {e}")
|
print(f"Error validating lib folder path: {e}")
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from crewai.flow.path_utils import safe_path_join, validate_path_exists
|
from crewai.flow.path_utils import validate_path_exists
|
||||||
|
|
||||||
|
|
||||||
class HTMLTemplateHandler:
|
class HTMLTemplateHandler:
|
||||||
@@ -53,23 +52,23 @@ class HTMLTemplateHandler:
|
|||||||
if "border" in item:
|
if "border" in item:
|
||||||
legend_items_html += f"""
|
legend_items_html += f"""
|
||||||
<div class="legend-item">
|
<div class="legend-item">
|
||||||
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
|
<div class="legend-color-box" style="background-color: {item["color"]}; border: 2px dashed {item["border"]};"></div>
|
||||||
<div>{item['label']}</div>
|
<div>{item["label"]}</div>
|
||||||
</div>
|
</div>
|
||||||
"""
|
"""
|
||||||
elif item.get("dashed") is not None:
|
elif item.get("dashed") is not None:
|
||||||
style = "dashed" if item["dashed"] else "solid"
|
style = "dashed" if item["dashed"] else "solid"
|
||||||
legend_items_html += f"""
|
legend_items_html += f"""
|
||||||
<div class="legend-item">
|
<div class="legend-item">
|
||||||
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
|
<div class="legend-{style}" style="border-bottom: 2px {style} {item["color"]};"></div>
|
||||||
<div>{item['label']}</div>
|
<div>{item["label"]}</div>
|
||||||
</div>
|
</div>
|
||||||
"""
|
"""
|
||||||
else:
|
else:
|
||||||
legend_items_html += f"""
|
legend_items_html += f"""
|
||||||
<div class="legend-item">
|
<div class="legend-item">
|
||||||
<div class="legend-color-box" style="background-color: {item['color']};"></div>
|
<div class="legend-color-box" style="background-color: {item["color"]};"></div>
|
||||||
<div>{item['label']}</div>
|
<div>{item["label"]}</div>
|
||||||
</div>
|
</div>
|
||||||
"""
|
"""
|
||||||
return legend_items_html
|
return legend_items_html
|
||||||
@@ -86,8 +85,6 @@ class HTMLTemplateHandler:
|
|||||||
final_html_content = final_html_content.replace(
|
final_html_content = final_html_content.replace(
|
||||||
"{{ logo_svg_base64 }}", logo_svg_base64
|
"{{ logo_svg_base64 }}", logo_svg_base64
|
||||||
)
|
)
|
||||||
final_html_content = final_html_content.replace(
|
return final_html_content.replace(
|
||||||
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
|
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
|
||||||
)
|
)
|
||||||
|
|
||||||
return final_html_content
|
|
||||||
|
|||||||
@@ -5,12 +5,10 @@ This module provides utilities for secure path handling to prevent directory
|
|||||||
traversal attacks and ensure paths remain within allowed boundaries.
|
traversal attacks and ensure paths remain within allowed boundaries.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union
|
|
||||||
|
|
||||||
|
|
||||||
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
|
def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
|
||||||
"""
|
"""
|
||||||
Safely join path components and ensure the result is within allowed boundaries.
|
Safely join path components and ensure the result is within allowed boundaries.
|
||||||
|
|
||||||
@@ -43,25 +41,25 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
|
|||||||
|
|
||||||
# Establish root directory
|
# Establish root directory
|
||||||
root_path = Path(root).resolve() if root else Path.cwd()
|
root_path = Path(root).resolve() if root else Path.cwd()
|
||||||
|
|
||||||
# Join and resolve the full path
|
# Join and resolve the full path
|
||||||
full_path = Path(root_path, *clean_parts).resolve()
|
full_path = Path(root_path, *clean_parts).resolve()
|
||||||
|
|
||||||
# Check if the resolved path is within root
|
# Check if the resolved path is within root
|
||||||
if not str(full_path).startswith(str(root_path)):
|
if not str(full_path).startswith(str(root_path)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
|
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return str(full_path)
|
return str(full_path)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, ValueError):
|
if isinstance(e, ValueError):
|
||||||
raise
|
raise
|
||||||
raise ValueError(f"Invalid path components: {str(e)}")
|
raise ValueError(f"Invalid path components: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
|
def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
|
||||||
"""
|
"""
|
||||||
Validate that a path exists and is of the expected type.
|
Validate that a path exists and is of the expected type.
|
||||||
|
|
||||||
@@ -84,24 +82,24 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
path_obj = Path(path).resolve()
|
path_obj = Path(path).resolve()
|
||||||
|
|
||||||
if not path_obj.exists():
|
if not path_obj.exists():
|
||||||
raise ValueError(f"Path does not exist: {path}")
|
raise ValueError(f"Path does not exist: {path}")
|
||||||
|
|
||||||
if file_type == "file" and not path_obj.is_file():
|
if file_type == "file" and not path_obj.is_file():
|
||||||
raise ValueError(f"Path is not a file: {path}")
|
raise ValueError(f"Path is not a file: {path}")
|
||||||
elif file_type == "directory" and not path_obj.is_dir():
|
if file_type == "directory" and not path_obj.is_dir():
|
||||||
raise ValueError(f"Path is not a directory: {path}")
|
raise ValueError(f"Path is not a directory: {path}")
|
||||||
|
|
||||||
return str(path_obj)
|
return str(path_obj)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, ValueError):
|
if isinstance(e, ValueError):
|
||||||
raise
|
raise
|
||||||
raise ValueError(f"Invalid path: {str(e)}")
|
raise ValueError(f"Invalid path: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
|
def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
|
||||||
"""
|
"""
|
||||||
Safely list files in a directory matching a pattern.
|
Safely list files in a directory matching a pattern.
|
||||||
|
|
||||||
@@ -126,10 +124,10 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
|
|||||||
dir_path = Path(directory).resolve()
|
dir_path = Path(directory).resolve()
|
||||||
if not dir_path.is_dir():
|
if not dir_path.is_dir():
|
||||||
raise ValueError(f"Not a directory: {directory}")
|
raise ValueError(f"Not a directory: {directory}")
|
||||||
|
|
||||||
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
|
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, ValueError):
|
if isinstance(e, ValueError):
|
||||||
raise
|
raise
|
||||||
raise ValueError(f"Error listing files: {str(e)}")
|
raise ValueError(f"Error listing files: {e!s}")
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from crewai.flow.persistence.base import FlowPersistence
|
|||||||
from crewai.flow.persistence.decorators import persist
|
from crewai.flow.persistence.decorators import persist
|
||||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||||
|
|
||||||
__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"]
|
__all__ = ["FlowPersistence", "SQLiteFlowPersistence", "persist"]
|
||||||
|
|
||||||
StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel])
|
StateType = TypeVar("StateType", bound=dict[str, Any] | BaseModel)
|
||||||
DictStateType = Dict[str, Any]
|
DictStateType = dict[str, Any]
|
||||||
|
|||||||
@@ -1,53 +1,47 @@
|
|||||||
"""Base class for flow state persistence."""
|
"""Base class for flow state persistence."""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class FlowPersistence(abc.ABC):
|
class FlowPersistence(abc.ABC):
|
||||||
"""Abstract base class for flow state persistence.
|
"""Abstract base class for flow state persistence.
|
||||||
|
|
||||||
This class defines the interface that all persistence implementations must follow.
|
This class defines the interface that all persistence implementations must follow.
|
||||||
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
|
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def init_db(self) -> None:
|
def init_db(self) -> None:
|
||||||
"""Initialize the persistence backend.
|
"""Initialize the persistence backend.
|
||||||
|
|
||||||
This method should handle any necessary setup, such as:
|
This method should handle any necessary setup, such as:
|
||||||
- Creating tables
|
- Creating tables
|
||||||
- Establishing connections
|
- Establishing connections
|
||||||
- Setting up indexes
|
- Setting up indexes
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def save_state(
|
def save_state(
|
||||||
self,
|
self, flow_uuid: str, method_name: str, state_data: dict[str, Any] | BaseModel
|
||||||
flow_uuid: str,
|
|
||||||
method_name: str,
|
|
||||||
state_data: Union[Dict[str, Any], BaseModel]
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Persist the flow state after method completion.
|
"""Persist the flow state after method completion.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
flow_uuid: Unique identifier for the flow instance
|
flow_uuid: Unique identifier for the flow instance
|
||||||
method_name: Name of the method that just completed
|
method_name: Name of the method that just completed
|
||||||
state_data: Current state data (either dict or Pydantic model)
|
state_data: Current state data (either dict or Pydantic model)
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
|
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||||
"""Load the most recent state for a given flow UUID.
|
"""Load the most recent state for a given flow UUID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
flow_uuid: Unique identifier for the flow instance
|
flow_uuid: Unique identifier for the flow instance
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The most recent state as a dictionary, or None if no state exists
|
The most recent state as a dictionary, or None if no state exists
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|||||||
@@ -24,13 +24,10 @@ Example:
|
|||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
|
||||||
Optional,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -48,7 +45,7 @@ LOG_MESSAGES = {
|
|||||||
"save_state": "Saving flow state to memory for ID: {}",
|
"save_state": "Saving flow state to memory for ID: {}",
|
||||||
"save_error": "Failed to persist state for method {}: {}",
|
"save_error": "Failed to persist state for method {}: {}",
|
||||||
"state_missing": "Flow instance has no state",
|
"state_missing": "Flow instance has no state",
|
||||||
"id_missing": "Flow state must have an 'id' field for persistence"
|
"id_missing": "Flow state must have an 'id' field for persistence",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -58,7 +55,13 @@ class PersistenceDecorator:
|
|||||||
_printer = Printer() # Class-level printer instance
|
_printer = Printer() # Class-level printer instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence, verbose: bool = False) -> None:
|
def persist_state(
|
||||||
|
cls,
|
||||||
|
flow_instance: Any,
|
||||||
|
method_name: str,
|
||||||
|
persistence_instance: FlowPersistence,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> None:
|
||||||
"""Persist flow state with proper error handling and logging.
|
"""Persist flow state with proper error handling and logging.
|
||||||
|
|
||||||
This method handles the persistence of flow state data, including proper
|
This method handles the persistence of flow state data, including proper
|
||||||
@@ -76,22 +79,24 @@ class PersistenceDecorator:
|
|||||||
AttributeError: If flow instance lacks required state attributes
|
AttributeError: If flow instance lacks required state attributes
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
state = getattr(flow_instance, 'state', None)
|
state = getattr(flow_instance, "state", None)
|
||||||
if state is None:
|
if state is None:
|
||||||
raise ValueError("Flow instance has no state")
|
raise ValueError("Flow instance has no state")
|
||||||
|
|
||||||
flow_uuid: Optional[str] = None
|
flow_uuid: str | None = None
|
||||||
if isinstance(state, dict):
|
if isinstance(state, dict):
|
||||||
flow_uuid = state.get('id')
|
flow_uuid = state.get("id")
|
||||||
elif isinstance(state, BaseModel):
|
elif isinstance(state, BaseModel):
|
||||||
flow_uuid = getattr(state, 'id', None)
|
flow_uuid = getattr(state, "id", None)
|
||||||
|
|
||||||
if not flow_uuid:
|
if not flow_uuid:
|
||||||
raise ValueError("Flow state must have an 'id' field for persistence")
|
raise ValueError("Flow state must have an 'id' field for persistence")
|
||||||
|
|
||||||
# Log state saving only if verbose is True
|
# Log state saving only if verbose is True
|
||||||
if verbose:
|
if verbose:
|
||||||
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
|
cls._printer.print(
|
||||||
|
LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan"
|
||||||
|
)
|
||||||
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))
|
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -104,7 +109,7 @@ class PersistenceDecorator:
|
|||||||
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
|
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
|
||||||
cls._printer.print(error_msg, color="red")
|
cls._printer.print(error_msg, color="red")
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
raise RuntimeError(f"State persistence failed: {str(e)}") from e
|
raise RuntimeError(f"State persistence failed: {e!s}") from e
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
error_msg = LOG_MESSAGES["state_missing"]
|
error_msg = LOG_MESSAGES["state_missing"]
|
||||||
cls._printer.print(error_msg, color="red")
|
cls._printer.print(error_msg, color="red")
|
||||||
@@ -117,7 +122,7 @@ class PersistenceDecorator:
|
|||||||
raise ValueError(error_msg) from e
|
raise ValueError(error_msg) from e
|
||||||
|
|
||||||
|
|
||||||
def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False):
|
def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
|
||||||
"""Decorator to persist flow state.
|
"""Decorator to persist flow state.
|
||||||
|
|
||||||
This decorator can be applied at either the class level or method level.
|
This decorator can be applied at either the class level or method level.
|
||||||
@@ -144,32 +149,33 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
|
|||||||
def begin(self):
|
def begin(self):
|
||||||
pass
|
pass
|
||||||
"""
|
"""
|
||||||
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
|
|
||||||
|
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
|
||||||
"""Decorator that handles both class and method decoration."""
|
"""Decorator that handles both class and method decoration."""
|
||||||
actual_persistence = persistence or SQLiteFlowPersistence()
|
actual_persistence = persistence or SQLiteFlowPersistence()
|
||||||
|
|
||||||
if isinstance(target, type):
|
if isinstance(target, type):
|
||||||
# Class decoration
|
# Class decoration
|
||||||
original_init = getattr(target, "__init__")
|
original_init = target.__init__
|
||||||
|
|
||||||
@functools.wraps(original_init)
|
@functools.wraps(original_init)
|
||||||
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
|
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
|
||||||
if 'persistence' not in kwargs:
|
if "persistence" not in kwargs:
|
||||||
kwargs['persistence'] = actual_persistence
|
kwargs["persistence"] = actual_persistence
|
||||||
original_init(self, *args, **kwargs)
|
original_init(self, *args, **kwargs)
|
||||||
|
|
||||||
setattr(target, "__init__", new_init)
|
target.__init__ = new_init
|
||||||
|
|
||||||
# Store original methods to preserve their decorators
|
# Store original methods to preserve their decorators
|
||||||
original_methods = {}
|
original_methods = {}
|
||||||
|
|
||||||
for name, method in target.__dict__.items():
|
for name, method in target.__dict__.items():
|
||||||
if callable(method) and (
|
if callable(method) and (
|
||||||
hasattr(method, "__is_start_method__") or
|
hasattr(method, "__is_start_method__")
|
||||||
hasattr(method, "__trigger_methods__") or
|
or hasattr(method, "__trigger_methods__")
|
||||||
hasattr(method, "__condition_type__") or
|
or hasattr(method, "__condition_type__")
|
||||||
hasattr(method, "__is_flow_method__") or
|
or hasattr(method, "__is_flow_method__")
|
||||||
hasattr(method, "__is_router__")
|
or hasattr(method, "__is_router__")
|
||||||
):
|
):
|
||||||
original_methods[name] = method
|
original_methods[name] = method
|
||||||
|
|
||||||
@@ -177,78 +183,116 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
|
|||||||
for name, method in original_methods.items():
|
for name, method in original_methods.items():
|
||||||
if asyncio.iscoroutinefunction(method):
|
if asyncio.iscoroutinefunction(method):
|
||||||
# Create a closure to capture the current name and method
|
# Create a closure to capture the current name and method
|
||||||
def create_async_wrapper(method_name: str, original_method: Callable):
|
def create_async_wrapper(
|
||||||
|
method_name: str, original_method: Callable
|
||||||
|
):
|
||||||
@functools.wraps(original_method)
|
@functools.wraps(original_method)
|
||||||
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
async def method_wrapper(
|
||||||
|
self: Any, *args: Any, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
result = await original_method(self, *args, **kwargs)
|
result = await original_method(self, *args, **kwargs)
|
||||||
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
|
PersistenceDecorator.persist_state(
|
||||||
|
self, method_name, actual_persistence, verbose
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return method_wrapper
|
return method_wrapper
|
||||||
|
|
||||||
wrapped = create_async_wrapper(name, method)
|
wrapped = create_async_wrapper(name, method)
|
||||||
|
|
||||||
# Preserve all original decorators and attributes
|
# Preserve all original decorators and attributes
|
||||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
for attr in [
|
||||||
|
"__is_start_method__",
|
||||||
|
"__trigger_methods__",
|
||||||
|
"__condition_type__",
|
||||||
|
"__is_router__",
|
||||||
|
]:
|
||||||
if hasattr(method, attr):
|
if hasattr(method, attr):
|
||||||
setattr(wrapped, attr, getattr(method, attr))
|
setattr(wrapped, attr, getattr(method, attr))
|
||||||
setattr(wrapped, "__is_flow_method__", True)
|
wrapped.__is_flow_method__ = True
|
||||||
|
|
||||||
# Update the class with the wrapped method
|
# Update the class with the wrapped method
|
||||||
setattr(target, name, wrapped)
|
setattr(target, name, wrapped)
|
||||||
else:
|
else:
|
||||||
# Create a closure to capture the current name and method
|
# Create a closure to capture the current name and method
|
||||||
def create_sync_wrapper(method_name: str, original_method: Callable):
|
def create_sync_wrapper(
|
||||||
|
method_name: str, original_method: Callable
|
||||||
|
):
|
||||||
@functools.wraps(original_method)
|
@functools.wraps(original_method)
|
||||||
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||||
result = original_method(self, *args, **kwargs)
|
result = original_method(self, *args, **kwargs)
|
||||||
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
|
PersistenceDecorator.persist_state(
|
||||||
|
self, method_name, actual_persistence, verbose
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return method_wrapper
|
return method_wrapper
|
||||||
|
|
||||||
wrapped = create_sync_wrapper(name, method)
|
wrapped = create_sync_wrapper(name, method)
|
||||||
|
|
||||||
# Preserve all original decorators and attributes
|
# Preserve all original decorators and attributes
|
||||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
for attr in [
|
||||||
|
"__is_start_method__",
|
||||||
|
"__trigger_methods__",
|
||||||
|
"__condition_type__",
|
||||||
|
"__is_router__",
|
||||||
|
]:
|
||||||
if hasattr(method, attr):
|
if hasattr(method, attr):
|
||||||
setattr(wrapped, attr, getattr(method, attr))
|
setattr(wrapped, attr, getattr(method, attr))
|
||||||
setattr(wrapped, "__is_flow_method__", True)
|
wrapped.__is_flow_method__ = True
|
||||||
|
|
||||||
# Update the class with the wrapped method
|
# Update the class with the wrapped method
|
||||||
setattr(target, name, wrapped)
|
setattr(target, name, wrapped)
|
||||||
|
|
||||||
return target
|
return target
|
||||||
else:
|
# Method decoration
|
||||||
# Method decoration
|
method = target
|
||||||
method = target
|
method.__is_flow_method__ = True
|
||||||
setattr(method, "__is_flow_method__", True)
|
|
||||||
|
|
||||||
if asyncio.iscoroutinefunction(method):
|
if asyncio.iscoroutinefunction(method):
|
||||||
@functools.wraps(method)
|
|
||||||
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
|
||||||
method_coro = method(flow_instance, *args, **kwargs)
|
|
||||||
if asyncio.iscoroutine(method_coro):
|
|
||||||
result = await method_coro
|
|
||||||
else:
|
|
||||||
result = method_coro
|
|
||||||
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
|
|
||||||
return result
|
|
||||||
|
|
||||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
@functools.wraps(method)
|
||||||
if hasattr(method, attr):
|
async def method_async_wrapper(
|
||||||
setattr(method_async_wrapper, attr, getattr(method, attr))
|
flow_instance: Any, *args: Any, **kwargs: Any
|
||||||
setattr(method_async_wrapper, "__is_flow_method__", True)
|
) -> T:
|
||||||
return cast(Callable[..., T], method_async_wrapper)
|
method_coro = method(flow_instance, *args, **kwargs)
|
||||||
else:
|
if asyncio.iscoroutine(method_coro):
|
||||||
@functools.wraps(method)
|
result = await method_coro
|
||||||
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
else:
|
||||||
result = method(flow_instance, *args, **kwargs)
|
result = method_coro
|
||||||
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
|
PersistenceDecorator.persist_state(
|
||||||
return result
|
flow_instance, method.__name__, actual_persistence, verbose
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
for attr in [
|
||||||
if hasattr(method, attr):
|
"__is_start_method__",
|
||||||
setattr(method_sync_wrapper, attr, getattr(method, attr))
|
"__trigger_methods__",
|
||||||
setattr(method_sync_wrapper, "__is_flow_method__", True)
|
"__condition_type__",
|
||||||
return cast(Callable[..., T], method_sync_wrapper)
|
"__is_router__",
|
||||||
|
]:
|
||||||
|
if hasattr(method, attr):
|
||||||
|
setattr(method_async_wrapper, attr, getattr(method, attr))
|
||||||
|
method_async_wrapper.__is_flow_method__ = True
|
||||||
|
return cast(Callable[..., T], method_async_wrapper)
|
||||||
|
|
||||||
|
@functools.wraps(method)
|
||||||
|
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||||
|
result = method(flow_instance, *args, **kwargs)
|
||||||
|
PersistenceDecorator.persist_state(
|
||||||
|
flow_instance, method.__name__, actual_persistence, verbose
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
for attr in [
|
||||||
|
"__is_start_method__",
|
||||||
|
"__trigger_methods__",
|
||||||
|
"__condition_type__",
|
||||||
|
"__is_router__",
|
||||||
|
]:
|
||||||
|
if hasattr(method, attr):
|
||||||
|
setattr(method_sync_wrapper, attr, getattr(method, attr))
|
||||||
|
method_sync_wrapper.__is_flow_method__ = True
|
||||||
|
return cast(Callable[..., T], method_sync_wrapper)
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import json
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
|||||||
|
|
||||||
db_path: str
|
db_path: str
|
||||||
|
|
||||||
def __init__(self, db_path: Optional[str] = None):
|
def __init__(self, db_path: str | None = None):
|
||||||
"""Initialize SQLite persistence.
|
"""Initialize SQLite persistence.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -70,7 +70,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
|||||||
self,
|
self,
|
||||||
flow_uuid: str,
|
flow_uuid: str,
|
||||||
method_name: str,
|
method_name: str,
|
||||||
state_data: Union[Dict[str, Any], BaseModel],
|
state_data: dict[str, Any] | BaseModel,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save the current flow state to SQLite.
|
"""Save the current flow state to SQLite.
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
|
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||||
"""Load the most recent state for a given flow UUID.
|
"""Load the most recent state for a given flow UUID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ the Flow system.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from typing_extensions import NotRequired, Required
|
from typing_extensions import NotRequired, Required
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ import ast
|
|||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from typing import Any, Deque, Dict, List, Optional, Set, Union
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def get_possible_return_constants(function: Any) -> Optional[List[str]]:
|
def get_possible_return_constants(function: Any) -> list[str] | None:
|
||||||
try:
|
try:
|
||||||
source = inspect.getsource(function)
|
source = inspect.getsource(function)
|
||||||
except OSError:
|
except OSError:
|
||||||
@@ -94,7 +94,7 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
|
|||||||
return list(return_values) if return_values else None
|
return list(return_values) if return_values else None
|
||||||
|
|
||||||
|
|
||||||
def calculate_node_levels(flow: Any) -> Dict[str, int]:
|
def calculate_node_levels(flow: Any) -> dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Calculate the hierarchical level of each node in the flow.
|
Calculate the hierarchical level of each node in the flow.
|
||||||
|
|
||||||
@@ -118,10 +118,10 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
|
|||||||
- Handles both OR and AND conditions for listeners
|
- Handles both OR and AND conditions for listeners
|
||||||
- Processes router paths separately
|
- Processes router paths separately
|
||||||
"""
|
"""
|
||||||
levels: Dict[str, int] = {}
|
levels: dict[str, int] = {}
|
||||||
queue: Deque[str] = deque()
|
queue: deque[str] = deque()
|
||||||
visited: Set[str] = set()
|
visited: set[str] = set()
|
||||||
pending_and_listeners: Dict[str, Set[str]] = {}
|
pending_and_listeners: dict[str, set[str]] = {}
|
||||||
|
|
||||||
# Make all start methods at level 0
|
# Make all start methods at level 0
|
||||||
for method_name, method in flow._methods.items():
|
for method_name, method in flow._methods.items():
|
||||||
@@ -172,7 +172,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
|
|||||||
return levels
|
return levels
|
||||||
|
|
||||||
|
|
||||||
def count_outgoing_edges(flow: Any) -> Dict[str, int]:
|
def count_outgoing_edges(flow: Any) -> dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Count the number of outgoing edges for each method in the flow.
|
Count the number of outgoing edges for each method in the flow.
|
||||||
|
|
||||||
@@ -197,7 +197,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
|
|||||||
return counts
|
return counts
|
||||||
|
|
||||||
|
|
||||||
def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
|
def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
|
||||||
"""
|
"""
|
||||||
Build a dictionary mapping each node to its ancestor nodes.
|
Build a dictionary mapping each node to its ancestor nodes.
|
||||||
|
|
||||||
@@ -211,8 +211,8 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
|
|||||||
Dict[str, Set[str]]
|
Dict[str, Set[str]]
|
||||||
Dictionary mapping each node to a set of its ancestor nodes.
|
Dictionary mapping each node to a set of its ancestor nodes.
|
||||||
"""
|
"""
|
||||||
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
|
ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
|
||||||
visited: Set[str] = set()
|
visited: set[str] = set()
|
||||||
for node in flow._methods:
|
for node in flow._methods:
|
||||||
if node not in visited:
|
if node not in visited:
|
||||||
dfs_ancestors(node, ancestors, visited, flow)
|
dfs_ancestors(node, ancestors, visited, flow)
|
||||||
@@ -220,7 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
|
|||||||
|
|
||||||
|
|
||||||
def dfs_ancestors(
|
def dfs_ancestors(
|
||||||
node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
|
node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Perform depth-first search to build ancestor relationships.
|
Perform depth-first search to build ancestor relationships.
|
||||||
@@ -265,7 +265,7 @@ def dfs_ancestors(
|
|||||||
|
|
||||||
|
|
||||||
def is_ancestor(
|
def is_ancestor(
|
||||||
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
|
node: str, ancestor_candidate: str, ancestors: dict[str, set[str]]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if one node is an ancestor of another.
|
Check if one node is an ancestor of another.
|
||||||
@@ -287,7 +287,7 @@ def is_ancestor(
|
|||||||
return ancestor_candidate in ancestors.get(node, set())
|
return ancestor_candidate in ancestors.get(node, set())
|
||||||
|
|
||||||
|
|
||||||
def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
|
def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
|
||||||
"""
|
"""
|
||||||
Build a dictionary mapping parent nodes to their children.
|
Build a dictionary mapping parent nodes to their children.
|
||||||
|
|
||||||
@@ -307,7 +307,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
|
|||||||
- Maps router methods to their paths and listeners
|
- Maps router methods to their paths and listeners
|
||||||
- Children lists are sorted for consistent ordering
|
- Children lists are sorted for consistent ordering
|
||||||
"""
|
"""
|
||||||
parent_children: Dict[str, List[str]] = {}
|
parent_children: dict[str, list[str]] = {}
|
||||||
|
|
||||||
# Map listeners to their trigger methods
|
# Map listeners to their trigger methods
|
||||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||||
@@ -332,7 +332,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
|
|||||||
|
|
||||||
|
|
||||||
def get_child_index(
|
def get_child_index(
|
||||||
parent: str, child: str, parent_children: Dict[str, List[str]]
|
parent: str, child: str, parent_children: dict[str, list[str]]
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Get the index of a child node in its parent's sorted children list.
|
Get the index of a child node in its parent's sorted children list.
|
||||||
@@ -364,7 +364,7 @@ def process_router_paths(flow, current, current_level, levels, queue):
|
|||||||
paths = flow._router_paths.get(current, [])
|
paths = flow._router_paths.get(current, [])
|
||||||
for path in paths:
|
for path in paths:
|
||||||
for listener_name, (
|
for listener_name, (
|
||||||
condition_type,
|
_condition_type,
|
||||||
trigger_methods,
|
trigger_methods,
|
||||||
) in flow._listeners.items():
|
) in flow._listeners.items():
|
||||||
if path in trigger_methods:
|
if path in trigger_methods:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ Example
|
|||||||
|
|
||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
build_ancestor_dict,
|
build_ancestor_dict,
|
||||||
@@ -56,6 +56,7 @@ def method_calls_crew(method: Any) -> bool:
|
|||||||
|
|
||||||
class CrewCallVisitor(ast.NodeVisitor):
|
class CrewCallVisitor(ast.NodeVisitor):
|
||||||
"""AST visitor to detect .crew() method calls."""
|
"""AST visitor to detect .crew() method calls."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.found = False
|
self.found = False
|
||||||
|
|
||||||
@@ -73,8 +74,8 @@ def method_calls_crew(method: Any) -> bool:
|
|||||||
def add_nodes_to_network(
|
def add_nodes_to_network(
|
||||||
net: Any,
|
net: Any,
|
||||||
flow: Any,
|
flow: Any,
|
||||||
node_positions: Dict[str, Tuple[float, float]],
|
node_positions: dict[str, tuple[float, float]],
|
||||||
node_styles: Dict[str, Dict[str, Any]]
|
node_styles: dict[str, dict[str, Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add nodes to the network visualization with appropriate styling.
|
Add nodes to the network visualization with appropriate styling.
|
||||||
@@ -98,6 +99,7 @@ def add_nodes_to_network(
|
|||||||
- Crew methods
|
- Crew methods
|
||||||
- Regular methods
|
- Regular methods
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def human_friendly_label(method_name):
|
def human_friendly_label(method_name):
|
||||||
return method_name.replace("_", " ").title()
|
return method_name.replace("_", " ").title()
|
||||||
|
|
||||||
@@ -138,10 +140,10 @@ def add_nodes_to_network(
|
|||||||
|
|
||||||
def compute_positions(
|
def compute_positions(
|
||||||
flow: Any,
|
flow: Any,
|
||||||
node_levels: Dict[str, int],
|
node_levels: dict[str, int],
|
||||||
y_spacing: float = 150,
|
y_spacing: float = 150,
|
||||||
x_spacing: float = 300
|
x_spacing: float = 300,
|
||||||
) -> Dict[str, Tuple[float, float]]:
|
) -> dict[str, tuple[float, float]]:
|
||||||
"""
|
"""
|
||||||
Compute the (x, y) positions for each node in the flow graph.
|
Compute the (x, y) positions for each node in the flow graph.
|
||||||
|
|
||||||
@@ -161,8 +163,8 @@ def compute_positions(
|
|||||||
Dict[str, Tuple[float, float]]
|
Dict[str, Tuple[float, float]]
|
||||||
Dictionary mapping node names to their (x, y) coordinates.
|
Dictionary mapping node names to their (x, y) coordinates.
|
||||||
"""
|
"""
|
||||||
level_nodes: Dict[int, List[str]] = {}
|
level_nodes: dict[int, list[str]] = {}
|
||||||
node_positions: Dict[str, Tuple[float, float]] = {}
|
node_positions: dict[str, tuple[float, float]] = {}
|
||||||
|
|
||||||
for method_name, level in node_levels.items():
|
for method_name, level in node_levels.items():
|
||||||
level_nodes.setdefault(level, []).append(method_name)
|
level_nodes.setdefault(level, []).append(method_name)
|
||||||
@@ -180,10 +182,10 @@ def compute_positions(
|
|||||||
def add_edges(
|
def add_edges(
|
||||||
net: Any,
|
net: Any,
|
||||||
flow: Any,
|
flow: Any,
|
||||||
node_positions: Dict[str, Tuple[float, float]],
|
node_positions: dict[str, tuple[float, float]],
|
||||||
colors: Dict[str, str]
|
colors: dict[str, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value
|
edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value
|
||||||
"""
|
"""
|
||||||
Add edges to the network visualization with appropriate styling.
|
Add edges to the network visualization with appropriate styling.
|
||||||
|
|
||||||
@@ -269,7 +271,7 @@ def add_edges(
|
|||||||
for router_method_name, paths in flow._router_paths.items():
|
for router_method_name, paths in flow._router_paths.items():
|
||||||
for path in paths:
|
for path in paths:
|
||||||
for listener_name, (
|
for listener_name, (
|
||||||
condition_type,
|
_condition_type,
|
||||||
trigger_methods,
|
trigger_methods,
|
||||||
) in flow._listeners.items():
|
) in flow._listeners.items():
|
||||||
if path in trigger_methods:
|
if path in trigger_methods:
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
@@ -14,19 +13,19 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
"""Base class for knowledge sources that load content from files."""
|
"""Base class for knowledge sources that load content from files."""
|
||||||
|
|
||||||
_logger: Logger = Logger(verbose=True)
|
_logger: Logger = Logger(verbose=True)
|
||||||
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
file_path: Path | list[Path] | str | list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="[Deprecated] The path to the file. Use file_paths instead.",
|
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||||
)
|
)
|
||||||
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
file_paths: Path | list[Path] | str | list[str] | None = Field(
|
||||||
default_factory=list, description="The path to the file"
|
default_factory=list, description="The path to the file"
|
||||||
)
|
)
|
||||||
content: Dict[Path, str] = Field(init=False, default_factory=dict)
|
content: dict[Path, str] = Field(init=False, default_factory=dict)
|
||||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
storage: KnowledgeStorage | None = Field(default=None)
|
||||||
safe_file_paths: List[Path] = Field(default_factory=list)
|
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||||
|
|
||||||
@field_validator("file_path", "file_paths", mode="before")
|
@field_validator("file_path", "file_paths", mode="before")
|
||||||
def validate_file_path(cls, v, info):
|
def validate_file_path(self, v, info):
|
||||||
"""Validate that at least one of file_path or file_paths is provided."""
|
"""Validate that at least one of file_path or file_paths is provided."""
|
||||||
# Single check if both are None, O(1) instead of nested conditions
|
# Single check if both are None, O(1) instead of nested conditions
|
||||||
if (
|
if (
|
||||||
@@ -46,9 +45,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
self.content = self.load_content()
|
self.content = self.load_content()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_content(self) -> Dict[Path, str]:
|
def load_content(self) -> dict[Path, str]:
|
||||||
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
||||||
pass
|
|
||||||
|
|
||||||
def validate_content(self):
|
def validate_content(self):
|
||||||
"""Validate the paths."""
|
"""Validate the paths."""
|
||||||
@@ -74,11 +72,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("No storage found to save documents.")
|
raise ValueError("No storage found to save documents.")
|
||||||
|
|
||||||
def convert_to_path(self, path: Union[Path, str]) -> Path:
|
def convert_to_path(self, path: Path | str) -> Path:
|
||||||
"""Convert a path to a Path object."""
|
"""Convert a path to a Path object."""
|
||||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||||
|
|
||||||
def _process_file_paths(self) -> List[Path]:
|
def _process_file_paths(self) -> list[Path]:
|
||||||
"""Convert file_path to a list of Path objects."""
|
"""Convert file_path to a list of Path objects."""
|
||||||
|
|
||||||
if hasattr(self, "file_path") and self.file_path is not None:
|
if hasattr(self, "file_path") and self.file_path is not None:
|
||||||
@@ -93,7 +91,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
raise ValueError("Your source must be provided with a file_paths: []")
|
raise ValueError("Your source must be provided with a file_paths: []")
|
||||||
|
|
||||||
# Convert single path to list
|
# Convert single path to list
|
||||||
path_list: List[Union[Path, str]] = (
|
path_list: list[Path | str] = (
|
||||||
[self.file_paths]
|
[self.file_paths]
|
||||||
if isinstance(self.file_paths, (str, Path))
|
if isinstance(self.file_paths, (str, Path))
|
||||||
else list(self.file_paths)
|
else list(self.file_paths)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
@@ -12,29 +12,27 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
|||||||
|
|
||||||
chunk_size: int = 4000
|
chunk_size: int = 4000
|
||||||
chunk_overlap: int = 200
|
chunk_overlap: int = 200
|
||||||
chunks: List[str] = Field(default_factory=list)
|
chunks: list[str] = Field(default_factory=list)
|
||||||
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
|
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
storage: KnowledgeStorage | None = Field(default=None)
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
|
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
|
||||||
collection_name: Optional[str] = Field(default=None)
|
collection_name: str | None = Field(default=None)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def validate_content(self) -> Any:
|
def validate_content(self) -> Any:
|
||||||
"""Load and preprocess content from the source."""
|
"""Load and preprocess content from the source."""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add(self) -> None:
|
def add(self) -> None:
|
||||||
"""Process content, chunk it, compute embeddings, and save them."""
|
"""Process content, chunk it, compute embeddings, and save them."""
|
||||||
pass
|
|
||||||
|
|
||||||
def get_embeddings(self) -> List[np.ndarray]:
|
def get_embeddings(self) -> list[np.ndarray]:
|
||||||
"""Return the list of embeddings for the chunks."""
|
"""Return the list of embeddings for the chunks."""
|
||||||
return self.chunk_embeddings
|
return self.chunk_embeddings
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> List[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
text[i : i + self.chunk_size]
|
text[i : i + self.chunk_size]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
from collections.abc import Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator, List, Optional, Union
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -35,11 +35,11 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
|||||||
|
|
||||||
_logger: Logger = Logger(verbose=True)
|
_logger: Logger = Logger(verbose=True)
|
||||||
|
|
||||||
file_path: Optional[List[Union[Path, str]]] = Field(default=None)
|
file_path: list[Path | str] | None = Field(default=None)
|
||||||
file_paths: List[Union[Path, str]] = Field(default_factory=list)
|
file_paths: list[Path | str] = Field(default_factory=list)
|
||||||
chunks: List[str] = Field(default_factory=list)
|
chunks: list[str] = Field(default_factory=list)
|
||||||
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
|
safe_file_paths: list[Path | str] = Field(default_factory=list)
|
||||||
content: List["DoclingDocument"] = Field(default_factory=list)
|
content: list["DoclingDocument"] = Field(default_factory=list)
|
||||||
document_converter: "DocumentConverter" = Field(
|
document_converter: "DocumentConverter" = Field(
|
||||||
default_factory=lambda: DocumentConverter(
|
default_factory=lambda: DocumentConverter(
|
||||||
allowed_formats=[
|
allowed_formats=[
|
||||||
@@ -66,7 +66,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
|||||||
self.safe_file_paths = self.validate_content()
|
self.safe_file_paths = self.validate_content()
|
||||||
self.content = self._load_content()
|
self.content = self._load_content()
|
||||||
|
|
||||||
def _load_content(self) -> List["DoclingDocument"]:
|
def _load_content(self) -> list["DoclingDocument"]:
|
||||||
try:
|
try:
|
||||||
return self._convert_source_to_docling_documents()
|
return self._convert_source_to_docling_documents()
|
||||||
except ConversionError as e:
|
except ConversionError as e:
|
||||||
@@ -88,7 +88,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
|||||||
self.chunks.extend(list(new_chunks_iterable))
|
self.chunks.extend(list(new_chunks_iterable))
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
def _convert_source_to_docling_documents(self) -> List["DoclingDocument"]:
|
def _convert_source_to_docling_documents(self) -> list["DoclingDocument"]:
|
||||||
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
||||||
return [result.document for result in conv_results_iter]
|
return [result.document for result in conv_results_iter]
|
||||||
|
|
||||||
@@ -97,8 +97,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
|||||||
for chunk in chunker.chunk(doc):
|
for chunk in chunker.chunk(doc):
|
||||||
yield chunk.text
|
yield chunk.text
|
||||||
|
|
||||||
def validate_content(self) -> List[Union[Path, str]]:
|
def validate_content(self) -> list[Path | str]:
|
||||||
processed_paths: List[Union[Path, str]] = []
|
processed_paths: list[Path | str] = []
|
||||||
for path in self.file_paths:
|
for path in self.file_paths:
|
||||||
if isinstance(path, str):
|
if isinstance(path, str):
|
||||||
if path.startswith(("http://", "https://")):
|
if path.startswith(("http://", "https://")):
|
||||||
@@ -108,7 +108,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid URL format: {path}")
|
raise ValueError(f"Invalid URL format: {path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
|
raise ValueError(f"Invalid URL: {path}. Error: {e!s}") from e
|
||||||
else:
|
else:
|
||||||
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
|
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
|
||||||
if local_path.exists():
|
if local_path.exists():
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import csv
|
import csv
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||||
|
|
||||||
@@ -8,7 +7,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
|||||||
class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
||||||
"""A knowledge source that stores and queries CSV file content using embeddings."""
|
"""A knowledge source that stores and queries CSV file content using embeddings."""
|
||||||
|
|
||||||
def load_content(self) -> Dict[Path, str]:
|
def load_content(self) -> dict[Path, str]:
|
||||||
"""Load and preprocess CSV file content."""
|
"""Load and preprocess CSV file content."""
|
||||||
content_dict = {}
|
content_dict = {}
|
||||||
for file_path in self.safe_file_paths:
|
for file_path in self.safe_file_paths:
|
||||||
@@ -32,7 +31,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> List[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
text[i : i + self.chunk_size]
|
text[i : i + self.chunk_size]
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Iterator, List, Optional, Union
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
@@ -16,19 +14,19 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
|
|
||||||
_logger: Logger = Logger(verbose=True)
|
_logger: Logger = Logger(verbose=True)
|
||||||
|
|
||||||
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
file_path: Path | list[Path] | str | list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="[Deprecated] The path to the file. Use file_paths instead.",
|
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||||
)
|
)
|
||||||
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
file_paths: Path | list[Path] | str | list[str] | None = Field(
|
||||||
default_factory=list, description="The path to the file"
|
default_factory=list, description="The path to the file"
|
||||||
)
|
)
|
||||||
chunks: List[str] = Field(default_factory=list)
|
chunks: list[str] = Field(default_factory=list)
|
||||||
content: Dict[Path, Dict[str, str]] = Field(default_factory=dict)
|
content: dict[Path, dict[str, str]] = Field(default_factory=dict)
|
||||||
safe_file_paths: List[Path] = Field(default_factory=list)
|
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||||
|
|
||||||
@field_validator("file_path", "file_paths", mode="before")
|
@field_validator("file_path", "file_paths", mode="before")
|
||||||
def validate_file_path(cls, v, info):
|
def validate_file_path(self, v, info):
|
||||||
"""Validate that at least one of file_path or file_paths is provided."""
|
"""Validate that at least one of file_path or file_paths is provided."""
|
||||||
# Single check if both are None, O(1) instead of nested conditions
|
# Single check if both are None, O(1) instead of nested conditions
|
||||||
if (
|
if (
|
||||||
@@ -41,7 +39,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
raise ValueError("Either file_path or file_paths must be provided")
|
raise ValueError("Either file_path or file_paths must be provided")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
def _process_file_paths(self) -> List[Path]:
|
def _process_file_paths(self) -> list[Path]:
|
||||||
"""Convert file_path to a list of Path objects."""
|
"""Convert file_path to a list of Path objects."""
|
||||||
|
|
||||||
if hasattr(self, "file_path") and self.file_path is not None:
|
if hasattr(self, "file_path") and self.file_path is not None:
|
||||||
@@ -56,7 +54,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
raise ValueError("Your source must be provided with a file_paths: []")
|
raise ValueError("Your source must be provided with a file_paths: []")
|
||||||
|
|
||||||
# Convert single path to list
|
# Convert single path to list
|
||||||
path_list: List[Union[Path, str]] = (
|
path_list: list[Path | str] = (
|
||||||
[self.file_paths]
|
[self.file_paths]
|
||||||
if isinstance(self.file_paths, (str, Path))
|
if isinstance(self.file_paths, (str, Path))
|
||||||
else list(self.file_paths)
|
else list(self.file_paths)
|
||||||
@@ -100,7 +98,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
self.validate_content()
|
self.validate_content()
|
||||||
self.content = self._load_content()
|
self.content = self._load_content()
|
||||||
|
|
||||||
def _load_content(self) -> Dict[Path, Dict[str, str]]:
|
def _load_content(self) -> dict[Path, dict[str, str]]:
|
||||||
"""Load and preprocess Excel file content from multiple sheets.
|
"""Load and preprocess Excel file content from multiple sheets.
|
||||||
|
|
||||||
Each sheet's content is converted to CSV format and stored.
|
Each sheet's content is converted to CSV format and stored.
|
||||||
@@ -126,7 +124,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
content_dict[file_path] = sheet_dict
|
content_dict[file_path] = sheet_dict
|
||||||
return content_dict
|
return content_dict
|
||||||
|
|
||||||
def convert_to_path(self, path: Union[Path, str]) -> Path:
|
def convert_to_path(self, path: Path | str) -> Path:
|
||||||
"""Convert a path to a Path object."""
|
"""Convert a path to a Path object."""
|
||||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||||
|
|
||||||
@@ -161,7 +159,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> List[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
text[i : i + self.chunk_size]
|
text[i : i + self.chunk_size]
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||||
|
|
||||||
@@ -8,9 +8,9 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
|||||||
class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
||||||
"""A knowledge source that stores and queries JSON file content using embeddings."""
|
"""A knowledge source that stores and queries JSON file content using embeddings."""
|
||||||
|
|
||||||
def load_content(self) -> Dict[Path, str]:
|
def load_content(self) -> dict[Path, str]:
|
||||||
"""Load and preprocess JSON file content."""
|
"""Load and preprocess JSON file content."""
|
||||||
content: Dict[Path, str] = {}
|
content: dict[Path, str] = {}
|
||||||
for path in self.safe_file_paths:
|
for path in self.safe_file_paths:
|
||||||
path = self.convert_to_path(path)
|
path = self.convert_to_path(path)
|
||||||
with open(path, "r", encoding="utf-8") as json_file:
|
with open(path, "r", encoding="utf-8") as json_file:
|
||||||
@@ -29,7 +29,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
for item in data:
|
for item in data:
|
||||||
text += f"{indent}- {self._json_to_text(item, level + 1)}\n"
|
text += f"{indent}- {self._json_to_text(item, level + 1)}\n"
|
||||||
else:
|
else:
|
||||||
text += f"{str(data)}"
|
text += f"{data!s}"
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def add(self) -> None:
|
def add(self) -> None:
|
||||||
@@ -44,7 +44,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> List[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
text[i : i + self.chunk_size]
|
text[i : i + self.chunk_size]
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||||
|
|
||||||
@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
|||||||
class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
||||||
"""A knowledge source that stores and queries PDF file content using embeddings."""
|
"""A knowledge source that stores and queries PDF file content using embeddings."""
|
||||||
|
|
||||||
def load_content(self) -> Dict[Path, str]:
|
def load_content(self) -> dict[Path, str]:
|
||||||
"""Load and preprocess PDF file content."""
|
"""Load and preprocess PDF file content."""
|
||||||
pdfplumber = self._import_pdfplumber()
|
pdfplumber = self._import_pdfplumber()
|
||||||
|
|
||||||
@@ -40,12 +39,12 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
Add PDF file content to the knowledge source, chunk it, compute embeddings,
|
Add PDF file content to the knowledge source, chunk it, compute embeddings,
|
||||||
and save the embeddings.
|
and save the embeddings.
|
||||||
"""
|
"""
|
||||||
for _, text in self.content.items():
|
for text in self.content.values():
|
||||||
new_chunks = self._chunk_text(text)
|
new_chunks = self._chunk_text(text)
|
||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> List[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
text[i : i + self.chunk_size]
|
text[i : i + self.chunk_size]
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
@@ -9,7 +7,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
|||||||
"""A knowledge source that stores and queries plain text content using embeddings."""
|
"""A knowledge source that stores and queries plain text content using embeddings."""
|
||||||
|
|
||||||
content: str = Field(...)
|
content: str = Field(...)
|
||||||
collection_name: Optional[str] = Field(default=None)
|
collection_name: str | None = Field(default=None)
|
||||||
|
|
||||||
def model_post_init(self, _):
|
def model_post_init(self, _):
|
||||||
"""Post-initialization method to validate content."""
|
"""Post-initialization method to validate content."""
|
||||||
@@ -26,7 +24,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> List[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
text[i : i + self.chunk_size]
|
text[i : i + self.chunk_size]
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||||
|
|
||||||
@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
|||||||
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
||||||
"""A knowledge source that stores and queries text file content using embeddings."""
|
"""A knowledge source that stores and queries text file content using embeddings."""
|
||||||
|
|
||||||
def load_content(self) -> Dict[Path, str]:
|
def load_content(self) -> dict[Path, str]:
|
||||||
"""Load and preprocess text file content."""
|
"""Load and preprocess text file content."""
|
||||||
content = {}
|
content = {}
|
||||||
for path in self.safe_file_paths:
|
for path in self.safe_file_paths:
|
||||||
@@ -21,12 +20,12 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
Add text file content to the knowledge source, chunk it, compute embeddings,
|
Add text file content to the knowledge source, chunk it, compute embeddings,
|
||||||
and save the embeddings.
|
and save the embeddings.
|
||||||
"""
|
"""
|
||||||
for _, text in self.content.items():
|
for text in self.content.values():
|
||||||
new_chunks = self._chunk_text(text)
|
new_chunks = self._chunk_text(text)
|
||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> List[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
text[i : i + self.chunk_size]
|
text[i : i + self.chunk_size]
|
||||||
|
|||||||
@@ -1,21 +1,14 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
cast,
|
cast,
|
||||||
get_args,
|
get_args,
|
||||||
get_origin,
|
get_origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Self
|
from typing import Self
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -24,11 +17,12 @@ except ImportError:
|
|||||||
from pydantic import (
|
from pydantic import (
|
||||||
UUID4,
|
UUID4,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
Field,
|
Field,
|
||||||
InstanceOf,
|
InstanceOf,
|
||||||
PrivateAttr,
|
PrivateAttr,
|
||||||
model_validator,
|
|
||||||
field_validator,
|
field_validator,
|
||||||
|
model_validator,
|
||||||
)
|
)
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
@@ -39,12 +33,18 @@ from crewai.agents.parser import (
|
|||||||
AgentFinish,
|
AgentFinish,
|
||||||
OutputParserException,
|
OutputParserException,
|
||||||
)
|
)
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.agent_events import (
|
||||||
|
LiteAgentExecutionCompletedEvent,
|
||||||
|
LiteAgentExecutionErrorEvent,
|
||||||
|
LiteAgentExecutionStartedEvent,
|
||||||
|
)
|
||||||
|
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
||||||
from crewai.flow.flow_trackable import FlowTrackable
|
from crewai.flow.flow_trackable import FlowTrackable
|
||||||
from crewai.llm import LLM, BaseLLM
|
from crewai.llm import LLM, BaseLLM
|
||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
from crewai.tools.structured_tool import CrewStructuredTool
|
from crewai.tools.structured_tool import CrewStructuredTool
|
||||||
from crewai.utilities import I18N
|
from crewai.utilities import I18N
|
||||||
from crewai.utilities.guardrail import process_guardrail
|
|
||||||
from crewai.utilities.agent_utils import (
|
from crewai.utilities.agent_utils import (
|
||||||
enforce_rpm_limit,
|
enforce_rpm_limit,
|
||||||
format_message_for_llm,
|
format_message_for_llm,
|
||||||
@@ -62,14 +62,7 @@ from crewai.utilities.agent_utils import (
|
|||||||
render_text_description_and_args,
|
render_text_description_and_args,
|
||||||
)
|
)
|
||||||
from crewai.utilities.converter import generate_model_description
|
from crewai.utilities.converter import generate_model_description
|
||||||
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
from crewai.utilities.guardrail import process_guardrail
|
||||||
from crewai.events.types.agent_events import (
|
|
||||||
LiteAgentExecutionCompletedEvent,
|
|
||||||
LiteAgentExecutionErrorEvent,
|
|
||||||
LiteAgentExecutionStartedEvent,
|
|
||||||
)
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
|
||||||
|
|
||||||
from crewai.utilities.llm_utils import create_llm
|
from crewai.utilities.llm_utils import create_llm
|
||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer
|
||||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||||
@@ -79,18 +72,18 @@ from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
|||||||
class LiteAgentOutput(BaseModel):
|
class LiteAgentOutput(BaseModel):
|
||||||
"""Class that represents the result of a LiteAgent execution."""
|
"""Class that represents the result of a LiteAgent execution."""
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
raw: str = Field(description="Raw output of the agent", default="")
|
raw: str = Field(description="Raw output of the agent", default="")
|
||||||
pydantic: Optional[BaseModel] = Field(
|
pydantic: BaseModel | None = Field(
|
||||||
description="Pydantic output of the agent", default=None
|
description="Pydantic output of the agent", default=None
|
||||||
)
|
)
|
||||||
agent_role: str = Field(description="Role of the agent that produced this output")
|
agent_role: str = Field(description="Role of the agent that produced this output")
|
||||||
usage_metrics: Optional[Dict[str, Any]] = Field(
|
usage_metrics: dict[str, Any] | None = Field(
|
||||||
description="Token usage metrics for this execution", default=None
|
description="Token usage metrics for this execution", default=None
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""Convert pydantic_output to a dictionary."""
|
"""Convert pydantic_output to a dictionary."""
|
||||||
if self.pydantic:
|
if self.pydantic:
|
||||||
return self.pydantic.model_dump()
|
return self.pydantic.model_dump()
|
||||||
@@ -123,17 +116,17 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
response_format: Optional Pydantic model for structured output.
|
response_format: Optional Pydantic model for structured output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
# Core Agent Properties
|
# Core Agent Properties
|
||||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||||
role: str = Field(description="Role of the agent")
|
role: str = Field(description="Role of the agent")
|
||||||
goal: str = Field(description="Goal of the agent")
|
goal: str = Field(description="Goal of the agent")
|
||||||
backstory: str = Field(description="Backstory of the agent")
|
backstory: str = Field(description="Backstory of the agent")
|
||||||
llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||||
default=None, description="Language model that will run the agent"
|
default=None, description="Language model that will run the agent"
|
||||||
)
|
)
|
||||||
tools: List[BaseTool] = Field(
|
tools: list[BaseTool] = Field(
|
||||||
default_factory=list, description="Tools at agent's disposal"
|
default_factory=list, description="Tools at agent's disposal"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -141,7 +134,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
max_iterations: int = Field(
|
max_iterations: int = Field(
|
||||||
default=15, description="Maximum number of iterations for tool usage"
|
default=15, description="Maximum number of iterations for tool usage"
|
||||||
)
|
)
|
||||||
max_execution_time: Optional[int] = Field(
|
max_execution_time: int | None = Field(
|
||||||
default=None, description=". Maximum execution time in seconds"
|
default=None, description=". Maximum execution time in seconds"
|
||||||
)
|
)
|
||||||
respect_context_window: bool = Field(
|
respect_context_window: bool = Field(
|
||||||
@@ -152,52 +145,50 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
default=True,
|
default=True,
|
||||||
description="Whether to use stop words to prevent the LLM from using tools",
|
description="Whether to use stop words to prevent the LLM from using tools",
|
||||||
)
|
)
|
||||||
request_within_rpm_limit: Optional[Callable[[], bool]] = Field(
|
request_within_rpm_limit: Callable[[], bool] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Callback to check if the request is within the RPM limit",
|
description="Callback to check if the request is within the RPM limit",
|
||||||
)
|
)
|
||||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||||
|
|
||||||
# Output and Formatting Properties
|
# Output and Formatting Properties
|
||||||
response_format: Optional[Type[BaseModel]] = Field(
|
response_format: type[BaseModel] | None = Field(
|
||||||
default=None, description="Pydantic model for structured output"
|
default=None, description="Pydantic model for structured output"
|
||||||
)
|
)
|
||||||
verbose: bool = Field(
|
verbose: bool = Field(
|
||||||
default=False, description="Whether to print execution details"
|
default=False, description="Whether to print execution details"
|
||||||
)
|
)
|
||||||
callbacks: List[Callable] = Field(
|
callbacks: list[Callable] = Field(
|
||||||
default=[], description="Callbacks to be used for the agent"
|
default=[], description="Callbacks to be used for the agent"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Guardrail Properties
|
# Guardrail Properties
|
||||||
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
|
guardrail: Callable[[LiteAgentOutput], tuple[bool, Any]] | str | None = Field(
|
||||||
Field(
|
default=None,
|
||||||
default=None,
|
description="Function or string description of a guardrail to validate agent output",
|
||||||
description="Function or string description of a guardrail to validate agent output",
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
guardrail_max_retries: int = Field(
|
guardrail_max_retries: int = Field(
|
||||||
default=3, description="Maximum number of retries when guardrail fails"
|
default=3, description="Maximum number of retries when guardrail fails"
|
||||||
)
|
)
|
||||||
|
|
||||||
# State and Results
|
# State and Results
|
||||||
tools_results: List[Dict[str, Any]] = Field(
|
tools_results: list[dict[str, Any]] = Field(
|
||||||
default=[], description="Results of the tools used by the agent."
|
default=[], description="Results of the tools used by the agent."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reference of Agent
|
# Reference of Agent
|
||||||
original_agent: Optional[BaseAgent] = Field(
|
original_agent: BaseAgent | None = Field(
|
||||||
default=None, description="Reference to the agent that created this LiteAgent"
|
default=None, description="Reference to the agent that created this LiteAgent"
|
||||||
)
|
)
|
||||||
# Private Attributes
|
# Private Attributes
|
||||||
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||||
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
||||||
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
|
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
|
||||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
|
||||||
_iterations: int = PrivateAttr(default=0)
|
_iterations: int = PrivateAttr(default=0)
|
||||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
_guardrail: Callable | None = PrivateAttr(default=None)
|
||||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
@@ -241,8 +232,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
@field_validator("guardrail", mode="before")
|
@field_validator("guardrail", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_guardrail_function(
|
def validate_guardrail_function(
|
||||||
cls, v: Optional[Union[Callable, str]]
|
cls, v: Callable | str | None
|
||||||
) -> Optional[Union[Callable, str]]:
|
) -> Callable | str | None:
|
||||||
"""Validate that the guardrail function has the correct signature.
|
"""Validate that the guardrail function has the correct signature.
|
||||||
|
|
||||||
If v is a callable, validate that it has the correct signature.
|
If v is a callable, validate that it has the correct signature.
|
||||||
@@ -267,7 +258,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
|
|
||||||
# Check return annotation if present
|
# Check return annotation if present
|
||||||
if sig.return_annotation is not sig.empty:
|
if sig.return_annotation is not sig.empty:
|
||||||
if sig.return_annotation == Tuple[bool, Any]:
|
if sig.return_annotation == tuple[bool, Any]:
|
||||||
return v
|
return v
|
||||||
|
|
||||||
origin = get_origin(sig.return_annotation)
|
origin = get_origin(sig.return_annotation)
|
||||||
@@ -290,7 +281,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
"""Return the original role for compatibility with tool interfaces."""
|
"""Return the original role for compatibility with tool interfaces."""
|
||||||
return self.role
|
return self.role
|
||||||
|
|
||||||
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
|
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
|
||||||
"""
|
"""
|
||||||
Execute the agent with the given messages.
|
Execute the agent with the given messages.
|
||||||
|
|
||||||
@@ -338,7 +329,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
|
def _execute_core(self, agent_info: dict[str, Any]) -> LiteAgentOutput:
|
||||||
# Emit event for agent execution start
|
# Emit event for agent execution start
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -351,7 +342,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
|
|
||||||
# Execute the agent using invoke loop
|
# Execute the agent using invoke loop
|
||||||
agent_finish = self._invoke_loop()
|
agent_finish = self._invoke_loop()
|
||||||
formatted_result: Optional[BaseModel] = None
|
formatted_result: BaseModel | None = None
|
||||||
if self.response_format:
|
if self.response_format:
|
||||||
try:
|
try:
|
||||||
# Cast to BaseModel to ensure type safety
|
# Cast to BaseModel to ensure type safety
|
||||||
@@ -360,7 +351,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
formatted_result = result
|
formatted_result = result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._printer.print(
|
self._printer.print(
|
||||||
content=f"Failed to parse output into response format: {str(e)}",
|
content=f"Failed to parse output into response format: {e!s}",
|
||||||
color="yellow",
|
color="yellow",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -428,7 +419,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
async def kickoff_async(
|
async def kickoff_async(
|
||||||
self, messages: Union[str, List[Dict[str, str]]]
|
self, messages: str | list[dict[str, str]]
|
||||||
) -> LiteAgentOutput:
|
) -> LiteAgentOutput:
|
||||||
"""
|
"""
|
||||||
Execute the agent asynchronously with the given messages.
|
Execute the agent asynchronously with the given messages.
|
||||||
@@ -475,8 +466,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
return base_prompt
|
return base_prompt
|
||||||
|
|
||||||
def _format_messages(
|
def _format_messages(
|
||||||
self, messages: Union[str, List[Dict[str, str]]]
|
self, messages: str | list[dict[str, str]]
|
||||||
) -> List[Dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
"""Format messages for the LLM."""
|
"""Format messages for the LLM."""
|
||||||
if isinstance(messages, str):
|
if isinstance(messages, str):
|
||||||
messages = [{"role": "user", "content": messages}]
|
messages = [{"role": "user", "content": messages}]
|
||||||
@@ -571,9 +562,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
i18n=self.i18n,
|
i18n=self.i18n,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
else:
|
handle_unknown_error(self._printer, e)
|
||||||
handle_unknown_error(self._printer, e)
|
raise e
|
||||||
raise e
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self._iterations += 1
|
self._iterations += 1
|
||||||
@@ -582,7 +572,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
self._show_logs(formatted_answer)
|
self._show_logs(formatted_answer)
|
||||||
return formatted_answer
|
return formatted_answer
|
||||||
|
|
||||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
def _show_logs(self, formatted_answer: AgentAction | AgentFinish):
|
||||||
"""Show logs for the agent's execution."""
|
"""Show logs for the agent's execution."""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -6,19 +6,14 @@ import threading
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
DefaultDict,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Type,
|
|
||||||
TypedDict,
|
TypedDict,
|
||||||
Union,
|
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
from datetime import datetime
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from litellm.types.utils import ChatCompletionDeltaToolCall
|
from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -31,9 +26,9 @@ from crewai.events.types.llm_events import (
|
|||||||
LLMStreamChunkEvent,
|
LLMStreamChunkEvent,
|
||||||
)
|
)
|
||||||
from crewai.events.types.tool_usage_events import (
|
from crewai.events.types.tool_usage_events import (
|
||||||
ToolUsageStartedEvent,
|
|
||||||
ToolUsageFinishedEvent,
|
|
||||||
ToolUsageErrorEvent,
|
ToolUsageErrorEvent,
|
||||||
|
ToolUsageFinishedEvent,
|
||||||
|
ToolUsageStartedEvent,
|
||||||
)
|
)
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
@@ -51,8 +46,8 @@ with warnings.catch_warnings():
|
|||||||
import io
|
import io
|
||||||
from typing import TextIO
|
from typing import TextIO
|
||||||
|
|
||||||
from crewai.llms.base_llm import BaseLLM
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.llms.base_llm import BaseLLM
|
||||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||||
LLMContextLengthExceededException,
|
LLMContextLengthExceededException,
|
||||||
)
|
)
|
||||||
@@ -268,14 +263,14 @@ def suppress_warnings():
|
|||||||
|
|
||||||
|
|
||||||
class Delta(TypedDict):
|
class Delta(TypedDict):
|
||||||
content: Optional[str]
|
content: str | None
|
||||||
role: Optional[str]
|
role: str | None
|
||||||
|
|
||||||
|
|
||||||
class StreamingChoices(TypedDict):
|
class StreamingChoices(TypedDict):
|
||||||
delta: Delta
|
delta: Delta
|
||||||
index: int
|
index: int
|
||||||
finish_reason: Optional[str]
|
finish_reason: str | None
|
||||||
|
|
||||||
|
|
||||||
class FunctionArgs(BaseModel):
|
class FunctionArgs(BaseModel):
|
||||||
@@ -288,31 +283,31 @@ class AccumulatedToolArgs(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class LLM(BaseLLM):
|
class LLM(BaseLLM):
|
||||||
completion_cost: Optional[float] = None
|
completion_cost: float | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
timeout: Optional[Union[float, int]] = None,
|
timeout: float | int | None = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: float | None = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: float | None = None,
|
||||||
n: Optional[int] = None,
|
n: int | None = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: str | list[str] | None = None,
|
||||||
max_completion_tokens: Optional[int] = None,
|
max_completion_tokens: int | None = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: int | None = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: float | None = None,
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: float | None = None,
|
||||||
logit_bias: Optional[Dict[int, float]] = None,
|
logit_bias: dict[int, float] | None = None,
|
||||||
response_format: Optional[Type[BaseModel]] = None,
|
response_format: type[BaseModel] | None = None,
|
||||||
seed: Optional[int] = None,
|
seed: int | None = None,
|
||||||
logprobs: Optional[int] = None,
|
logprobs: int | None = None,
|
||||||
top_logprobs: Optional[int] = None,
|
top_logprobs: int | None = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: str | None = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: str | None = None,
|
||||||
api_version: Optional[str] = None,
|
api_version: str | None = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: str | None = None,
|
||||||
callbacks: List[Any] | None = None,
|
callbacks: list[Any] | None = None,
|
||||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -345,7 +340,7 @@ class LLM(BaseLLM):
|
|||||||
|
|
||||||
# Normalize self.stop to always be a List[str]
|
# Normalize self.stop to always be a List[str]
|
||||||
if stop is None:
|
if stop is None:
|
||||||
self.stop: List[str] = []
|
self.stop: list[str] = []
|
||||||
elif isinstance(stop, str):
|
elif isinstance(stop, str):
|
||||||
self.stop = [stop]
|
self.stop = [stop]
|
||||||
else:
|
else:
|
||||||
@@ -368,9 +363,9 @@ class LLM(BaseLLM):
|
|||||||
|
|
||||||
def _prepare_completion_params(
|
def _prepare_completion_params(
|
||||||
self,
|
self,
|
||||||
messages: Union[str, List[Dict[str, str]]],
|
messages: str | list[dict[str, str]],
|
||||||
tools: Optional[List[dict]] = None,
|
tools: list[dict] | None = None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Prepare parameters for the completion call.
|
"""Prepare parameters for the completion call.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -419,11 +414,11 @@ class LLM(BaseLLM):
|
|||||||
|
|
||||||
def _handle_streaming_response(
|
def _handle_streaming_response(
|
||||||
self,
|
self,
|
||||||
params: Dict[str, Any],
|
params: dict[str, Any],
|
||||||
callbacks: Optional[List[Any]] = None,
|
callbacks: list[Any] | None = None,
|
||||||
available_functions: Optional[Dict[str, Any]] = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Optional[Any] = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Optional[Any] = None,
|
from_agent: Any | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle a streaming response from the LLM.
|
"""Handle a streaming response from the LLM.
|
||||||
|
|
||||||
@@ -447,7 +442,7 @@ class LLM(BaseLLM):
|
|||||||
usage_info = None
|
usage_info = None
|
||||||
tool_calls = None
|
tool_calls = None
|
||||||
|
|
||||||
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
|
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
|
||||||
AccumulatedToolArgs
|
AccumulatedToolArgs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -472,16 +467,16 @@ class LLM(BaseLLM):
|
|||||||
choices = chunk["choices"]
|
choices = chunk["choices"]
|
||||||
elif hasattr(chunk, "choices"):
|
elif hasattr(chunk, "choices"):
|
||||||
# Check if choices is not a type but an actual attribute with value
|
# Check if choices is not a type but an actual attribute with value
|
||||||
if not isinstance(getattr(chunk, "choices"), type):
|
if not isinstance(chunk.choices, type):
|
||||||
choices = getattr(chunk, "choices")
|
choices = chunk.choices
|
||||||
|
|
||||||
# Try to extract usage information if available
|
# Try to extract usage information if available
|
||||||
if isinstance(chunk, dict) and "usage" in chunk:
|
if isinstance(chunk, dict) and "usage" in chunk:
|
||||||
usage_info = chunk["usage"]
|
usage_info = chunk["usage"]
|
||||||
elif hasattr(chunk, "usage"):
|
elif hasattr(chunk, "usage"):
|
||||||
# Check if usage is not a type but an actual attribute with value
|
# Check if usage is not a type but an actual attribute with value
|
||||||
if not isinstance(getattr(chunk, "usage"), type):
|
if not isinstance(chunk.usage, type):
|
||||||
usage_info = getattr(chunk, "usage")
|
usage_info = chunk.usage
|
||||||
|
|
||||||
if choices and len(choices) > 0:
|
if choices and len(choices) > 0:
|
||||||
choice = choices[0]
|
choice = choices[0]
|
||||||
@@ -491,7 +486,7 @@ class LLM(BaseLLM):
|
|||||||
if isinstance(choice, dict) and "delta" in choice:
|
if isinstance(choice, dict) and "delta" in choice:
|
||||||
delta = choice["delta"]
|
delta = choice["delta"]
|
||||||
elif hasattr(choice, "delta"):
|
elif hasattr(choice, "delta"):
|
||||||
delta = getattr(choice, "delta")
|
delta = choice.delta
|
||||||
|
|
||||||
# Extract content from delta
|
# Extract content from delta
|
||||||
if delta:
|
if delta:
|
||||||
@@ -501,7 +496,7 @@ class LLM(BaseLLM):
|
|||||||
chunk_content = delta["content"]
|
chunk_content = delta["content"]
|
||||||
# Handle object format
|
# Handle object format
|
||||||
elif hasattr(delta, "content"):
|
elif hasattr(delta, "content"):
|
||||||
chunk_content = getattr(delta, "content")
|
chunk_content = delta.content
|
||||||
|
|
||||||
# Handle case where content might be None or empty
|
# Handle case where content might be None or empty
|
||||||
if chunk_content is None and isinstance(delta, dict):
|
if chunk_content is None and isinstance(delta, dict):
|
||||||
@@ -572,8 +567,8 @@ class LLM(BaseLLM):
|
|||||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||||
choices = last_chunk["choices"]
|
choices = last_chunk["choices"]
|
||||||
elif hasattr(last_chunk, "choices"):
|
elif hasattr(last_chunk, "choices"):
|
||||||
if not isinstance(getattr(last_chunk, "choices"), type):
|
if not isinstance(last_chunk.choices, type):
|
||||||
choices = getattr(last_chunk, "choices")
|
choices = last_chunk.choices
|
||||||
|
|
||||||
if choices and len(choices) > 0:
|
if choices and len(choices) > 0:
|
||||||
choice = choices[0]
|
choice = choices[0]
|
||||||
@@ -583,14 +578,14 @@ class LLM(BaseLLM):
|
|||||||
if isinstance(choice, dict) and "message" in choice:
|
if isinstance(choice, dict) and "message" in choice:
|
||||||
message = choice["message"]
|
message = choice["message"]
|
||||||
elif hasattr(choice, "message"):
|
elif hasattr(choice, "message"):
|
||||||
message = getattr(choice, "message")
|
message = choice.message
|
||||||
|
|
||||||
if message:
|
if message:
|
||||||
content = None
|
content = None
|
||||||
if isinstance(message, dict) and "content" in message:
|
if isinstance(message, dict) and "content" in message:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
elif hasattr(message, "content"):
|
elif hasattr(message, "content"):
|
||||||
content = getattr(message, "content")
|
content = message.content
|
||||||
|
|
||||||
if content:
|
if content:
|
||||||
full_response = content
|
full_response = content
|
||||||
@@ -617,8 +612,8 @@ class LLM(BaseLLM):
|
|||||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||||
choices = last_chunk["choices"]
|
choices = last_chunk["choices"]
|
||||||
elif hasattr(last_chunk, "choices"):
|
elif hasattr(last_chunk, "choices"):
|
||||||
if not isinstance(getattr(last_chunk, "choices"), type):
|
if not isinstance(last_chunk.choices, type):
|
||||||
choices = getattr(last_chunk, "choices")
|
choices = last_chunk.choices
|
||||||
|
|
||||||
if choices and len(choices) > 0:
|
if choices and len(choices) > 0:
|
||||||
choice = choices[0]
|
choice = choices[0]
|
||||||
@@ -627,13 +622,13 @@ class LLM(BaseLLM):
|
|||||||
if isinstance(choice, dict) and "message" in choice:
|
if isinstance(choice, dict) and "message" in choice:
|
||||||
message = choice["message"]
|
message = choice["message"]
|
||||||
elif hasattr(choice, "message"):
|
elif hasattr(choice, "message"):
|
||||||
message = getattr(choice, "message")
|
message = choice.message
|
||||||
|
|
||||||
if message:
|
if message:
|
||||||
if isinstance(message, dict) and "tool_calls" in message:
|
if isinstance(message, dict) and "tool_calls" in message:
|
||||||
tool_calls = message["tool_calls"]
|
tool_calls = message["tool_calls"]
|
||||||
elif hasattr(message, "tool_calls"):
|
elif hasattr(message, "tool_calls"):
|
||||||
tool_calls = getattr(message, "tool_calls")
|
tool_calls = message.tool_calls
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.debug(f"Error checking for tool calls: {e}")
|
logging.debug(f"Error checking for tool calls: {e}")
|
||||||
# --- 8) If no tool calls or no available functions, return the text response directly
|
# --- 8) If no tool calls or no available functions, return the text response directly
|
||||||
@@ -675,9 +670,9 @@ class LLM(BaseLLM):
|
|||||||
# decide whether to summarize the content or abort based on the respect_context_window flag.
|
# decide whether to summarize the content or abort based on the respect_context_window flag.
|
||||||
raise LLMContextLengthExceededException(str(e))
|
raise LLMContextLengthExceededException(str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in streaming response: {str(e)}")
|
logging.error(f"Error in streaming response: {e!s}")
|
||||||
if full_response.strip():
|
if full_response.strip():
|
||||||
logging.warning(f"Returning partial response despite error: {str(e)}")
|
logging.warning(f"Returning partial response despite error: {e!s}")
|
||||||
self._handle_emit_call_events(
|
self._handle_emit_call_events(
|
||||||
response=full_response,
|
response=full_response,
|
||||||
call_type=LLMCallType.LLM_CALL,
|
call_type=LLMCallType.LLM_CALL,
|
||||||
@@ -695,15 +690,15 @@ class LLM(BaseLLM):
|
|||||||
error=str(e), from_task=from_task, from_agent=from_agent
|
error=str(e), from_task=from_task, from_agent=from_agent
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
raise Exception(f"Failed to get streaming response: {str(e)}")
|
raise Exception(f"Failed to get streaming response: {e!s}")
|
||||||
|
|
||||||
def _handle_streaming_tool_calls(
|
def _handle_streaming_tool_calls(
|
||||||
self,
|
self,
|
||||||
tool_calls: List[ChatCompletionDeltaToolCall],
|
tool_calls: list[ChatCompletionDeltaToolCall],
|
||||||
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
|
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
|
||||||
available_functions: Optional[Dict[str, Any]] = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Optional[Any] = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Optional[Any] = None,
|
from_agent: Any | None = None,
|
||||||
) -> None | str:
|
) -> None | str:
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
current_tool_accumulator = accumulated_tool_args[tool_call.index]
|
current_tool_accumulator = accumulated_tool_args[tool_call.index]
|
||||||
@@ -744,9 +739,9 @@ class LLM(BaseLLM):
|
|||||||
|
|
||||||
def _handle_streaming_callbacks(
|
def _handle_streaming_callbacks(
|
||||||
self,
|
self,
|
||||||
callbacks: Optional[List[Any]],
|
callbacks: list[Any] | None,
|
||||||
usage_info: Optional[Dict[str, Any]],
|
usage_info: dict[str, Any] | None,
|
||||||
last_chunk: Optional[Any],
|
last_chunk: Any | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle callbacks with usage info for streaming responses.
|
"""Handle callbacks with usage info for streaming responses.
|
||||||
|
|
||||||
@@ -769,10 +764,8 @@ class LLM(BaseLLM):
|
|||||||
):
|
):
|
||||||
usage_info = last_chunk["usage"]
|
usage_info = last_chunk["usage"]
|
||||||
elif hasattr(last_chunk, "usage"):
|
elif hasattr(last_chunk, "usage"):
|
||||||
if not isinstance(
|
if not isinstance(last_chunk.usage, type):
|
||||||
getattr(last_chunk, "usage"), type
|
usage_info = last_chunk.usage
|
||||||
):
|
|
||||||
usage_info = getattr(last_chunk, "usage")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.debug(f"Error extracting usage info: {e}")
|
logging.debug(f"Error extracting usage info: {e}")
|
||||||
|
|
||||||
@@ -786,11 +779,11 @@ class LLM(BaseLLM):
|
|||||||
|
|
||||||
def _handle_non_streaming_response(
|
def _handle_non_streaming_response(
|
||||||
self,
|
self,
|
||||||
params: Dict[str, Any],
|
params: dict[str, Any],
|
||||||
callbacks: Optional[List[Any]] = None,
|
callbacks: list[Any] | None = None,
|
||||||
available_functions: Optional[Dict[str, Any]] = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Optional[Any] = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Optional[Any] = None,
|
from_agent: Any | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Handle a non-streaming response from the LLM.
|
"""Handle a non-streaming response from the LLM.
|
||||||
|
|
||||||
@@ -847,7 +840,7 @@ class LLM(BaseLLM):
|
|||||||
)
|
)
|
||||||
return text_response
|
return text_response
|
||||||
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
|
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
|
||||||
elif tool_calls and not available_functions and not text_response:
|
if tool_calls and not available_functions and not text_response:
|
||||||
return tool_calls
|
return tool_calls
|
||||||
|
|
||||||
# --- 7) Handle tool calls if present
|
# --- 7) Handle tool calls if present
|
||||||
@@ -868,11 +861,11 @@ class LLM(BaseLLM):
|
|||||||
|
|
||||||
def _handle_tool_call(
|
def _handle_tool_call(
|
||||||
self,
|
self,
|
||||||
tool_calls: List[Any],
|
tool_calls: list[Any],
|
||||||
available_functions: Optional[Dict[str, Any]] = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Optional[Any] = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Optional[Any] = None,
|
from_agent: Any | None = None,
|
||||||
) -> Optional[str]:
|
) -> str | None:
|
||||||
"""Handle a tool call from the LLM.
|
"""Handle a tool call from the LLM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -942,14 +935,14 @@ class LLM(BaseLLM):
|
|||||||
assert hasattr(crewai_event_bus, "emit")
|
assert hasattr(crewai_event_bus, "emit")
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
|
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
|
||||||
)
|
)
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=ToolUsageErrorEvent(
|
event=ToolUsageErrorEvent(
|
||||||
tool_name=function_name,
|
tool_name=function_name,
|
||||||
tool_args=function_args,
|
tool_args=function_args,
|
||||||
error=f"Tool execution error: {str(e)}",
|
error=f"Tool execution error: {e!s}",
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
@@ -958,13 +951,13 @@ class LLM(BaseLLM):
|
|||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
messages: Union[str, List[Dict[str, str]]],
|
messages: str | list[dict[str, str]],
|
||||||
tools: Optional[List[dict]] = None,
|
tools: list[dict] | None = None,
|
||||||
callbacks: Optional[List[Any]] = None,
|
callbacks: list[Any] | None = None,
|
||||||
available_functions: Optional[Dict[str, Any]] = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Optional[Any] = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Optional[Any] = None,
|
from_agent: Any | None = None,
|
||||||
) -> Union[str, Any]:
|
) -> str | Any:
|
||||||
"""High-level LLM call method.
|
"""High-level LLM call method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1028,10 +1021,9 @@ class LLM(BaseLLM):
|
|||||||
return self._handle_streaming_response(
|
return self._handle_streaming_response(
|
||||||
params, callbacks, available_functions, from_task, from_agent
|
params, callbacks, available_functions, from_task, from_agent
|
||||||
)
|
)
|
||||||
else:
|
return self._handle_non_streaming_response(
|
||||||
return self._handle_non_streaming_response(
|
params, callbacks, available_functions, from_task, from_agent
|
||||||
params, callbacks, available_functions, from_task, from_agent
|
)
|
||||||
)
|
|
||||||
|
|
||||||
except LLMContextLengthExceededException:
|
except LLMContextLengthExceededException:
|
||||||
# Re-raise LLMContextLengthExceededException as it should be handled
|
# Re-raise LLMContextLengthExceededException as it should be handled
|
||||||
@@ -1078,8 +1070,8 @@ class LLM(BaseLLM):
|
|||||||
self,
|
self,
|
||||||
response: Any,
|
response: Any,
|
||||||
call_type: LLMCallType,
|
call_type: LLMCallType,
|
||||||
from_task: Optional[Any] = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Optional[Any] = None,
|
from_agent: Any | None = None,
|
||||||
messages: str | list[dict[str, Any]] | None = None,
|
messages: str | list[dict[str, Any]] | None = None,
|
||||||
):
|
):
|
||||||
"""Handle the events for the LLM call.
|
"""Handle the events for the LLM call.
|
||||||
@@ -1105,8 +1097,8 @@ class LLM(BaseLLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _format_messages_for_provider(
|
def _format_messages_for_provider(
|
||||||
self, messages: List[Dict[str, str]]
|
self, messages: list[dict[str, str]]
|
||||||
) -> List[Dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
"""Format messages according to provider requirements.
|
"""Format messages according to provider requirements.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1147,7 +1139,7 @@ class LLM(BaseLLM):
|
|||||||
if "mistral" in self.model.lower():
|
if "mistral" in self.model.lower():
|
||||||
# Check if the last message has a role of 'assistant'
|
# Check if the last message has a role of 'assistant'
|
||||||
if messages and messages[-1]["role"] == "assistant":
|
if messages and messages[-1]["role"] == "assistant":
|
||||||
return messages + [{"role": "user", "content": "Please continue."}]
|
return [*messages, {"role": "user", "content": "Please continue."}]
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
|
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
|
||||||
@@ -1157,7 +1149,7 @@ class LLM(BaseLLM):
|
|||||||
and messages
|
and messages
|
||||||
and messages[-1]["role"] == "assistant"
|
and messages[-1]["role"] == "assistant"
|
||||||
):
|
):
|
||||||
return messages + [{"role": "user", "content": ""}]
|
return [*messages, {"role": "user", "content": ""}]
|
||||||
|
|
||||||
# Handle Anthropic models
|
# Handle Anthropic models
|
||||||
if not self.is_anthropic:
|
if not self.is_anthropic:
|
||||||
@@ -1170,7 +1162,7 @@ class LLM(BaseLLM):
|
|||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def _get_custom_llm_provider(self) -> Optional[str]:
|
def _get_custom_llm_provider(self) -> str | None:
|
||||||
"""
|
"""
|
||||||
Derives the custom_llm_provider from the model string.
|
Derives the custom_llm_provider from the model string.
|
||||||
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
|
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
|
||||||
@@ -1207,7 +1199,7 @@ class LLM(BaseLLM):
|
|||||||
self.model, custom_llm_provider=provider
|
self.model, custom_llm_provider=provider
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to check function calling support: {str(e)}")
|
logging.error(f"Failed to check function calling support: {e!s}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def supports_stop_words(self) -> bool:
|
def supports_stop_words(self) -> bool:
|
||||||
@@ -1215,7 +1207,7 @@ class LLM(BaseLLM):
|
|||||||
params = get_supported_openai_params(model=self.model)
|
params = get_supported_openai_params(model=self.model)
|
||||||
return params is not None and "stop" in params
|
return params is not None and "stop" in params
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to get supported params: {str(e)}")
|
logging.error(f"Failed to get supported params: {e!s}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_context_window_size(self) -> int:
|
def get_context_window_size(self) -> int:
|
||||||
@@ -1247,7 +1239,7 @@ class LLM(BaseLLM):
|
|||||||
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
|
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
|
||||||
return self.context_window_size
|
return self.context_window_size
|
||||||
|
|
||||||
def set_callbacks(self, callbacks: List[Any]):
|
def set_callbacks(self, callbacks: list[Any]):
|
||||||
"""
|
"""
|
||||||
Attempt to keep a single set of callbacks in litellm by removing old
|
Attempt to keep a single set of callbacks in litellm by removing old
|
||||||
duplicates and adding new ones.
|
duplicates and adding new ones.
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from .entity.entity_memory import EntityMemory
|
from .entity.entity_memory import EntityMemory
|
||||||
|
from .external.external_memory import ExternalMemory
|
||||||
from .long_term.long_term_memory import LongTermMemory
|
from .long_term.long_term_memory import LongTermMemory
|
||||||
from .short_term.short_term_memory import ShortTermMemory
|
from .short_term.short_term_memory import ShortTermMemory
|
||||||
from .external.external_memory import ExternalMemory
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EntityMemory",
|
"EntityMemory",
|
||||||
|
"ExternalMemory",
|
||||||
"LongTermMemory",
|
"LongTermMemory",
|
||||||
"ShortTermMemory",
|
"ShortTermMemory",
|
||||||
"ExternalMemory",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class ExternalMemoryItem:
|
class ExternalMemoryItem:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
value: Any,
|
value: Any,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
agent: Optional[str] = None,
|
agent: str | None = None,
|
||||||
):
|
):
|
||||||
self.value = value
|
self.value = value
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
from typing import Any, Dict, List
|
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
|
||||||
from crewai.memory.memory import Memory
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.memory_events import (
|
from crewai.events.types.memory_events import (
|
||||||
MemoryQueryStartedEvent,
|
|
||||||
MemoryQueryCompletedEvent,
|
MemoryQueryCompletedEvent,
|
||||||
MemoryQueryFailedEvent,
|
MemoryQueryFailedEvent,
|
||||||
MemorySaveStartedEvent,
|
MemoryQueryStartedEvent,
|
||||||
MemorySaveCompletedEvent,
|
MemorySaveCompletedEvent,
|
||||||
MemorySaveFailedEvent,
|
MemorySaveFailedEvent,
|
||||||
|
MemorySaveStartedEvent,
|
||||||
)
|
)
|
||||||
|
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||||
|
from crewai.memory.memory import Memory
|
||||||
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||||
|
|
||||||
|
|
||||||
@@ -84,7 +84,7 @@ class LongTermMemory(Memory):
|
|||||||
self,
|
self,
|
||||||
task: str,
|
task: str,
|
||||||
latest_n: int = 3,
|
latest_n: int = 3,
|
||||||
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemoryQueryStartedEvent(
|
event=MemoryQueryStartedEvent(
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class LongTermMemoryItem:
|
class LongTermMemoryItem:
|
||||||
@@ -8,8 +8,8 @@ class LongTermMemoryItem:
|
|||||||
task: str,
|
task: str,
|
||||||
expected_output: str,
|
expected_output: str,
|
||||||
datetime: str,
|
datetime: str,
|
||||||
quality: Optional[Union[int, float]] = None,
|
quality: int | float | None = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
self.task = task
|
self.task = task
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class ShortTermMemoryItem:
|
class ShortTermMemoryItem:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data: Any,
|
data: Any,
|
||||||
agent: Optional[str] = None,
|
agent: str | None = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
self.data = data
|
self.data = data
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class Storage:
|
class Storage:
|
||||||
"""Abstract base class defining the storage interface"""
|
"""Abstract base class defining the storage interface"""
|
||||||
|
|
||||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, query: str, limit: int, score_threshold: float
|
self, query: str, limit: int, score_threshold: float
|
||||||
) -> Dict[str, Any] | List[Any]:
|
) -> dict[str, Any] | list[Any]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
from crewai.utilities import Printer
|
from crewai.utilities import Printer
|
||||||
@@ -18,7 +18,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
|||||||
An updated SQLite storage class for kickoff task outputs storage.
|
An updated SQLite storage class for kickoff task outputs storage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db_path: Optional[str] = None) -> None:
|
def __init__(self, db_path: str | None = None) -> None:
|
||||||
if db_path is None:
|
if db_path is None:
|
||||||
# Get the parent directory of the default db path and create our db file there
|
# Get the parent directory of the default db path and create our db file there
|
||||||
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
|
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
|
||||||
@@ -62,10 +62,10 @@ class KickoffTaskOutputsSQLiteStorage:
|
|||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
task: Task,
|
task: Task,
|
||||||
output: Dict[str, Any],
|
output: dict[str, Any],
|
||||||
task_index: int,
|
task_index: int,
|
||||||
was_replayed: bool = False,
|
was_replayed: bool = False,
|
||||||
inputs: Dict[str, Any] | None = None,
|
inputs: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a new task output record to the database.
|
"""Add a new task output record to the database.
|
||||||
|
|
||||||
@@ -153,7 +153,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
raise DatabaseOperationError(error_msg, e)
|
raise DatabaseOperationError(error_msg, e)
|
||||||
|
|
||||||
def load(self) -> List[Dict[str, Any]]:
|
def load(self) -> list[dict[str, Any]]:
|
||||||
"""Load all task output records from the database.
|
"""Load all task output records from the database.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
from crewai.utilities import Printer
|
from crewai.utilities import Printer
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
@@ -12,9 +12,7 @@ class LTMSQLiteStorage:
|
|||||||
An updated SQLite storage class for LTM data storage.
|
An updated SQLite storage class for LTM data storage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_path: str | None = None) -> None:
|
||||||
self, db_path: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
if db_path is None:
|
if db_path is None:
|
||||||
# Get the parent directory of the default db path and create our db file there
|
# Get the parent directory of the default db path and create our db file there
|
||||||
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
||||||
@@ -53,9 +51,9 @@ class LTMSQLiteStorage:
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
task_description: str,
|
task_description: str,
|
||||||
metadata: Dict[str, Any],
|
metadata: dict[str, Any],
|
||||||
datetime: str,
|
datetime: str,
|
||||||
score: Union[int, float],
|
score: int | float,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Saves data to the LTM table with error handling."""
|
"""Saves data to the LTM table with error handling."""
|
||||||
try:
|
try:
|
||||||
@@ -75,9 +73,7 @@ class LTMSQLiteStorage:
|
|||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
|
|
||||||
def load(
|
def load(self, task_description: str, latest_n: int) -> list[dict[str, Any]] | None:
|
||||||
self, task_description: str, latest_n: int
|
|
||||||
) -> Optional[List[Dict[str, Any]]]:
|
|
||||||
"""Queries the LTM table by task description with error handling."""
|
"""Queries the LTM table by task description with error handling."""
|
||||||
try:
|
try:
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
@@ -125,4 +121,4 @@ class LTMSQLiteStorage:
|
|||||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
return None
|
return
|
||||||
|
|||||||
@@ -14,16 +14,16 @@ from .annotations import (
|
|||||||
from .crew_base import CrewBase
|
from .crew_base import CrewBase
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"CrewBase",
|
||||||
|
"after_kickoff",
|
||||||
"agent",
|
"agent",
|
||||||
|
"before_kickoff",
|
||||||
|
"cache_handler",
|
||||||
|
"callback",
|
||||||
"crew",
|
"crew",
|
||||||
"task",
|
"llm",
|
||||||
"output_json",
|
"output_json",
|
||||||
"output_pydantic",
|
"output_pydantic",
|
||||||
|
"task",
|
||||||
"tool",
|
"tool",
|
||||||
"callback",
|
|
||||||
"CrewBase",
|
|
||||||
"llm",
|
|
||||||
"cache_handler",
|
|
||||||
"before_kickoff",
|
|
||||||
"after_kickoff",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
from crewai import Crew
|
from crewai import Crew
|
||||||
from crewai.project.utils import memoize
|
from crewai.project.utils import memoize
|
||||||
@@ -36,15 +36,13 @@ def task(func):
|
|||||||
def agent(func):
|
def agent(func):
|
||||||
"""Marks a method as a crew agent."""
|
"""Marks a method as a crew agent."""
|
||||||
func.is_agent = True
|
func.is_agent = True
|
||||||
func = memoize(func)
|
return memoize(func)
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
def llm(func):
|
def llm(func):
|
||||||
"""Marks a method as an LLM provider."""
|
"""Marks a method as an LLM provider."""
|
||||||
func.is_llm = True
|
func.is_llm = True
|
||||||
func = memoize(func)
|
return memoize(func)
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
def output_json(cls):
|
def output_json(cls):
|
||||||
@@ -91,7 +89,7 @@ def crew(func) -> Callable[..., Crew]:
|
|||||||
agents = self._original_agents.items()
|
agents = self._original_agents.items()
|
||||||
|
|
||||||
# Instantiate tasks in order
|
# Instantiate tasks in order
|
||||||
for task_name, task_method in tasks:
|
for _task_name, task_method in tasks:
|
||||||
task_instance = task_method(self)
|
task_instance = task_method(self)
|
||||||
instantiated_tasks.append(task_instance)
|
instantiated_tasks.append(task_instance)
|
||||||
agent_instance = getattr(task_instance, "agent", None)
|
agent_instance = getattr(task_instance, "agent", None)
|
||||||
@@ -100,7 +98,7 @@ def crew(func) -> Callable[..., Crew]:
|
|||||||
agent_roles.add(agent_instance.role)
|
agent_roles.add(agent_instance.role)
|
||||||
|
|
||||||
# Instantiate agents not included by tasks
|
# Instantiate agents not included by tasks
|
||||||
for agent_name, agent_method in agents:
|
for _agent_name, agent_method in agents:
|
||||||
agent_instance = agent_method(self)
|
agent_instance = agent_method(self)
|
||||||
if agent_instance.role not in agent_roles:
|
if agent_instance.role not in agent_roles:
|
||||||
instantiated_agents.append(agent_instance)
|
instantiated_agents.append(agent_instance)
|
||||||
@@ -117,9 +115,9 @@ def crew(func) -> Callable[..., Crew]:
|
|||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
for _, callback in self._before_kickoff.items():
|
for callback in self._before_kickoff.values():
|
||||||
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
|
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
|
||||||
for _, callback in self._after_kickoff.items():
|
for callback in self._after_kickoff.values():
|
||||||
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))
|
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))
|
||||||
|
|
||||||
return crew
|
return crew
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||||
|
|
||||||
import sys
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import sys
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from crewai.rag.config.types import RagConfigType
|
from crewai.rag.config.types import RagConfigType
|
||||||
from crewai.rag.config.utils import set_rag_config
|
from crewai.rag.config.utils import set_rag_config
|
||||||
|
|
||||||
|
|
||||||
_module_path = __path__
|
_module_path = __path__
|
||||||
_module_file = __file__
|
_module_file = __file__
|
||||||
|
|
||||||
|
|
||||||
class _RagModule(ModuleType):
|
class _RagModule(ModuleType):
|
||||||
"""Module wrapper to intercept attribute setting for config."""
|
"""Module wrapper to intercept attribute setting for config."""
|
||||||
|
|
||||||
@@ -51,8 +51,10 @@ class _RagModule(ModuleType):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return importlib.import_module(f"{self.__name__}.{name}")
|
return importlib.import_module(f"{self.__name__}.{name}")
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
|
raise AttributeError(
|
||||||
|
f"module '{self.__name__}' has no attribute '{name}'"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
sys.modules[__name__] = _RagModule(__name__)
|
sys.modules[__name__] = _RagModule(__name__)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
"""Optional imports for RAG configuration providers."""
|
"""Optional imports for RAG configuration providers."""
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Base classes for missing provider configurations."""
|
"""Base classes for missing provider configurations."""
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user