Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-05-07 10:12:38 +00:00

Commit: Merge branch 'crewAIInc:main' into main
@@ -14,7 +14,7 @@ warnings.filterwarnings(
     category=UserWarning,
     module="pydantic.main",
 )
-__version__ = "0.98.0"
+__version__ = "0.100.1"
 __all__ = [
     "Agent",
     "Crew",
@@ -1,14 +1,13 @@
 import os
 import re
 import shutil
 import subprocess
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Sequence, Union

 from pydantic import Field, InstanceOf, PrivateAttr, model_validator

 from crewai.agents import CacheHandler
 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.crew_agent_executor import CrewAgentExecutor
 from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS
 from crewai.knowledge.knowledge import Knowledge
 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
@@ -17,7 +16,6 @@ from crewai.memory.contextual.contextual_memory import ContextualMemory
 from crewai.task import Task
 from crewai.tools import BaseTool
 from crewai.tools.agent_tools.agent_tools import AgentTools
 from crewai.tools.base_tool import Tool
 from crewai.utilities import Converter, Prompts
 from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
 from crewai.utilities.converter import generate_model_description
@@ -56,13 +54,13 @@ class Agent(BaseAgent):
         llm: The language model that will run the agent.
         function_calling_llm: The language model that handles tool calling for this agent; it overrides the crew's function_calling_llm.
         max_iter: Maximum number of iterations for an agent to execute a task.
         memory: Whether the agent should have memory or not.
         max_rpm: Maximum number of requests per minute to be respected during agent execution.
         verbose: Whether the agent execution should be in verbose mode.
         allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
         tools: Tools at the agent's disposal.
         step_callback: Callback to be executed after each step of the agent execution.
         knowledge_sources: Knowledge sources for the agent.
         embedder: Embedder configuration for the agent.
     """

     _times_executed: int = PrivateAttr(default=0)
@@ -72,9 +70,6 @@ class Agent(BaseAgent):
     )
     agent_ops_agent_name: str = None  # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
     agent_ops_agent_id: str = None  # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
-    cache_handler: InstanceOf[CacheHandler] = Field(
-        default=None, description="An instance of the CacheHandler class."
-    )
     step_callback: Optional[Any] = Field(
         default=None,
         description="Callback to be executed after each step of the agent execution.",
@@ -108,10 +103,6 @@ class Agent(BaseAgent):
         default=True,
         description="Keep messages under the context window size by summarizing content.",
     )
-    max_iter: int = Field(
-        default=20,
-        description="Maximum number of iterations for an agent to execute a task before giving it's best answer",
-    )
     max_retry_limit: int = Field(
         default=2,
         description="Maximum number of retries for an agent to execute a task when an error occurs.",
@@ -124,17 +115,10 @@ class Agent(BaseAgent):
         default="safe",
         description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
     )
-    embedder_config: Optional[Dict[str, Any]] = Field(
+    embedder: Optional[Dict[str, Any]] = Field(
         default=None,
         description="Embedder configuration for the agent.",
     )
-    knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
-        default=None,
-        description="Knowledge sources for the agent.",
-    )
-    _knowledge: Optional[Knowledge] = PrivateAttr(
-        default=None,
-    )

     @model_validator(mode="after")
     def post_init_setup(self):
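The embedder_config -> embedder rename above is a public API change on Agent. Example (not part of the diff; the provider/config dict shape shown here is an assumption, not something this diff confirms):

    from crewai import Agent

    agent = Agent(
        role="Researcher",
        goal="Find relevant papers",
        backstory="An experienced research assistant.",
        embedder={  # formerly `embedder_config`
            "provider": "openai",
            "config": {"model": "text-embedding-3-small"},
        },
    )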
@@ -161,14 +145,16 @@ class Agent(BaseAgent):
     def _set_knowledge(self):
         try:
             if self.knowledge_sources:
-                knowledge_agent_name = f"{self.role.replace(' ', '_')}"
+                full_pattern = re.compile(r"[^a-zA-Z0-9\-_\r\n]|(\.\.)")
+                knowledge_agent_name = f"{re.sub(full_pattern, '_', self.role)}"
                 if isinstance(self.knowledge_sources, list) and all(
                     isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
                 ):
-                    self._knowledge = Knowledge(
+                    self.knowledge = Knowledge(
                         sources=self.knowledge_sources,
-                        embedder_config=self.embedder_config,
+                        embedder=self.embedder,
                         collection_name=knowledge_agent_name,
+                        storage=self.knowledge_storage or None,
                     )
         except (TypeError, ValueError) as e:
             raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
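The new sanitization above replaces every character outside [a-zA-Z0-9-_\r\n] (and any ".." sequence) with "_", so arbitrary role strings become safe collection names instead of only having spaces swapped. Example (not part of the diff):

    import re

    full_pattern = re.compile(r"[^a-zA-Z0-9\-_\r\n]|(\.\.)")
    print(re.sub(full_pattern, "_", "Senior Data Scientist"))  # Senior_Data_Scientist
    print(re.sub(full_pattern, "_", "QA/..Lead"))              # QA___Lead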
@@ -202,13 +188,15 @@ class Agent(BaseAgent):
         if task.output_json:
             # schema = json.dumps(task.output_json, indent=2)
             schema = generate_model_description(task.output_json)
+            task_prompt += "\n" + self.i18n.slice(
+                "formatted_task_instructions"
+            ).format(output_format=schema)

         elif task.output_pydantic:
             schema = generate_model_description(task.output_pydantic)
-
-            task_prompt += "\n" + self.i18n.slice("formatted_task_instructions").format(
-                output_format=schema
-            )
+            task_prompt += "\n" + self.i18n.slice(
+                "formatted_task_instructions"
+            ).format(output_format=schema)

         if context:
             task_prompt = self.i18n.slice("task_with_context").format(
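The change above makes the formatted_task_instructions slice apply to output_json tasks as well as output_pydantic ones. Example of the path being exercised (not part of the diff; the model and task fields are illustrative):

    from pydantic import BaseModel
    from crewai import Task

    class ReportOutput(BaseModel):
        title: str
        summary: str

    task = Task(
        description="Summarize the findings",
        expected_output="A short structured report",
        output_pydantic=ReportOutput,  # its schema is appended to the task prompt
    )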
@@ -227,8 +215,8 @@ class Agent(BaseAgent):
             if memory.strip() != "":
                 task_prompt += self.i18n.slice("memory").format(memory=memory)

-        if self._knowledge:
-            agent_knowledge_snippets = self._knowledge.query([task.prompt()])
+        if self.knowledge:
+            agent_knowledge_snippets = self.knowledge.query([task.prompt()])
             if agent_knowledge_snippets:
                 agent_knowledge_context = extract_knowledge_context(
                     agent_knowledge_snippets
@@ -261,6 +249,9 @@ class Agent(BaseAgent):
                 }
             )["output"]
         except Exception as e:
+            if e.__class__.__module__.startswith("litellm"):
+                # Do not retry on litellm errors
+                raise e
             self._times_executed += 1
             if self._times_executed > self.max_retry_limit:
                 raise e
@@ -333,14 +324,14 @@ class Agent(BaseAgent):
             tools = agent_tools.tools()
         return tools

-    def get_multimodal_tools(self) -> List[Tool]:
+    def get_multimodal_tools(self) -> Sequence[BaseTool]:
         from crewai.tools.agent_tools.add_image_tool import AddImageTool

         return [AddImageTool()]

     def get_code_execution_tools(self):
         try:
-            from crewai_tools import CodeInterpreterTool
+            from crewai_tools import CodeInterpreterTool  # type: ignore

             # Set the unsafe_mode based on the code_execution_mode attribute
             unsafe_mode = self.code_execution_mode == "unsafe"
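Example of the safe/unsafe switch referenced above (not part of the diff; illustrative field values only - allow_code_execution is what pulls CodeInterpreterTool in):

    from crewai import Agent

    agent = Agent(
        role="Data Analyst",
        goal="Run quick pandas snippets",
        backstory="Comfortable writing Python.",
        allow_code_execution=True,
        code_execution_mode="unsafe",  # "safe" (the default) runs inside Docker
    )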
@@ -18,10 +18,13 @@ from pydantic_core import PydanticCustomError
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.agents.cache.cache_handler import CacheHandler
 from crewai.agents.tools_handler import ToolsHandler
+from crewai.knowledge.knowledge import Knowledge
+from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.tools import BaseTool
 from crewai.tools.base_tool import Tool
 from crewai.utilities import I18N, Logger, RPMController
 from crewai.utilities.config import process_config
+from crewai.utilities.converter import Converter

 T = TypeVar("T", bound="BaseAgent")
@@ -40,7 +43,7 @@ class BaseAgent(ABC, BaseModel):
         max_rpm (Optional[int]): Maximum number of requests per minute for the agent execution.
         allow_delegation (bool): Allow delegation of tasks to agents.
         tools (Optional[List[Any]]): Tools at the agent's disposal.
-        max_iter (Optional[int]): Maximum iterations for an agent to execute a task.
+        max_iter (int): Maximum iterations for an agent to execute a task.
         agent_executor (InstanceOf): An instance of the CrewAgentExecutor class.
         llm (Any): Language model that will run the agent.
         crew (Any): Crew to which the agent belongs.
@@ -48,6 +51,8 @@ class BaseAgent(ABC, BaseModel):
         cache_handler (InstanceOf[CacheHandler]): An instance of the CacheHandler class.
         tools_handler (InstanceOf[ToolsHandler]): An instance of the ToolsHandler class.
         max_tokens: Maximum number of tokens for the agent to generate in a response.
+        knowledge_sources: Knowledge sources for the agent.
+        knowledge_storage: Custom knowledge storage for the agent.


     Methods:
@@ -114,7 +119,7 @@ class BaseAgent(ABC, BaseModel):
     tools: Optional[List[Any]] = Field(
         default_factory=list, description="Tools at agents' disposal"
     )
-    max_iter: Optional[int] = Field(
+    max_iter: int = Field(
         default=25, description="Maximum iterations for an agent to execute a task"
     )
     agent_executor: InstanceOf = Field(
@@ -125,15 +130,27 @@ class BaseAgent(ABC, BaseModel):
     )
     crew: Any = Field(default=None, description="Crew to which the agent belongs.")
     i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
-    cache_handler: InstanceOf[CacheHandler] = Field(
+    cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
         default=None, description="An instance of the CacheHandler class."
     )
     tools_handler: InstanceOf[ToolsHandler] = Field(
-        default=None, description="An instance of the ToolsHandler class."
+        default_factory=ToolsHandler,
+        description="An instance of the ToolsHandler class.",
     )
     max_tokens: Optional[int] = Field(
         default=None, description="Maximum number of tokens for the agent's execution."
     )
+    knowledge: Optional[Knowledge] = Field(
+        default=None, description="Knowledge for the agent."
+    )
+    knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
+        default=None,
+        description="Knowledge sources for the agent.",
+    )
+    knowledge_storage: Optional[Any] = Field(
+        default=None,
+        description="Custom knowledge storage for the agent.",
+    )

     @model_validator(mode="before")
     @classmethod
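With knowledge, knowledge_sources, and knowledge_storage now living on BaseAgent, they can be passed at construction time. Example (not part of the diff; StringKnowledgeSource and its import path are assumptions based on the crewai.knowledge package):

    from crewai import Agent
    from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

    source = StringKnowledgeSource(content="Our refund window is 30 days.")
    agent = Agent(
        role="Support Agent",
        goal="Answer policy questions",
        backstory="Knows company policy.",
        knowledge_sources=[source],
    )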
@@ -249,7 +266,7 @@ class BaseAgent(ABC, BaseModel):
     @abstractmethod
     def get_output_converter(
         self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
-    ):
+    ) -> Converter:
         """Get the converter class for the agent to create json/pydantic outputs."""
         pass
@@ -266,13 +283,44 @@ class BaseAgent(ABC, BaseModel):
             "tools_handler",
             "cache_handler",
             "llm",
+            "knowledge_sources",
+            "knowledge_storage",
+            "knowledge",
         }

-        # Copy llm and clear callbacks
+        # Copy llm
         existing_llm = shallow_copy(self.llm)
+        copied_knowledge = shallow_copy(self.knowledge)
+        copied_knowledge_storage = shallow_copy(self.knowledge_storage)
+        # Properly copy knowledge sources if they exist
+        existing_knowledge_sources = None
+        if self.knowledge_sources:
+            # Create a shared storage instance for all knowledge sources
+            shared_storage = (
+                self.knowledge_sources[0].storage if self.knowledge_sources else None
+            )
+
+            existing_knowledge_sources = []
+            for source in self.knowledge_sources:
+                copied_source = (
+                    source.model_copy()
+                    if hasattr(source, "model_copy")
+                    else shallow_copy(source)
+                )
+                # Ensure all copied sources use the same storage instance
+                copied_source.storage = shared_storage
+                existing_knowledge_sources.append(copied_source)
+
         copied_data = self.model_dump(exclude=exclude)
         copied_data = {k: v for k, v in copied_data.items() if v is not None}
-        copied_agent = type(self)(**copied_data, llm=existing_llm, tools=self.tools)
+        copied_agent = type(self)(
+            **copied_data,
+            llm=existing_llm,
+            tools=self.tools,
+            knowledge_sources=existing_knowledge_sources,
+            knowledge=copied_knowledge,
+            knowledge_storage=copied_knowledge_storage,
+        )

         return copied_agent
@@ -95,18 +95,29 @@ class CrewAgentExecutorMixin:
         pass

     def _ask_human_input(self, final_answer: str) -> str:
-        """Prompt human input for final decision making."""
+        """Prompt human input with mode-appropriate messaging."""
         self._printer.print(
             content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
         )

-        self._printer.print(
-            content=(
+        # Training mode prompt (single iteration)
+        if self.crew and getattr(self.crew, "_train", False):
+            prompt = (
                 "\n\n=====\n"
-                "## Please provide feedback on the Final Result and the Agent's actions. "
-                "Respond with 'looks good' or a similar phrase when you're satisfied.\n"
+                "## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
+                "This will be used to train better versions of the agent.\n"
+                "Please provide detailed feedback about the result quality and reasoning process.\n"
                 "=====\n"
-            ),
-            color="bold_yellow",
-        )
+            )
+        # Regular human-in-the-loop prompt (multiple iterations)
+        else:
+            prompt = (
+                "\n\n=====\n"
+                "## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n"
+                "Respond with 'looks good' to accept or provide specific improvement requests.\n"
+                "You can provide multiple rounds of feedback until satisfied.\n"
+                "=====\n"
+            )

+        self._printer.print(content=prompt, color="bold_yellow")
         return input()
@@ -13,6 +13,7 @@ from crewai.agents.parser import (
     OutputParserException,
 )
 from crewai.agents.tools_handler import ToolsHandler
+from crewai.llm import LLM
 from crewai.tools.base_tool import BaseTool
 from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
 from crewai.utilities import I18N, Printer
@@ -54,7 +55,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         callbacks: List[Any] = [],
     ):
         self._i18n: I18N = I18N()
-        self.llm = llm
+        self.llm: LLM = llm
         self.task = task
         self.agent = agent
         self.crew = crew
@@ -80,10 +81,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.tool_name_to_tool_map: Dict[str, BaseTool] = {
             tool.name: tool for tool in self.tools
         }
-        if self.llm.stop:
-            self.llm.stop = list(set(self.llm.stop + self.stop))
-        else:
-            self.llm.stop = self.stop
+        self.stop = stop_words
+        self.llm.stop = list(set(self.llm.stop + self.stop))

     def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
         if "system" in self.prompt:
@@ -98,7 +97,22 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             self._show_start_logs()

         self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
-        formatted_answer = self._invoke_loop()
+
+        try:
+            formatted_answer = self._invoke_loop()
+        except AssertionError:
+            self._printer.print(
+                content="Agent failed to reach a final answer. This is likely a bug - please report it.",
+                color="red",
+            )
+            raise
+        except Exception as e:
+            if e.__class__.__module__.startswith("litellm"):
+                # Do not retry on litellm errors
+                raise e
+            else:
+                self._handle_unknown_error(e)
+                raise e

         if self.ask_for_human_input:
             formatted_answer = self._handle_human_feedback(formatted_answer)
@@ -107,7 +121,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self._create_long_term_memory(formatted_answer)
         return {"output": formatted_answer.output}

-    def _invoke_loop(self):
+    def _invoke_loop(self) -> AgentFinish:
         """
         Main loop to invoke the agent's thought process until it reaches a conclusion
         or the maximum number of iterations is reached.
@@ -124,7 +138,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 self._enforce_rpm_limit()

                 answer = self._get_llm_response()
-
                 formatted_answer = self._process_llm_response(answer)

                 if isinstance(formatted_answer, AgentAction):
@@ -142,13 +155,37 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 formatted_answer = self._handle_output_parser_exception(e)

             except Exception as e:
+                if e.__class__.__module__.startswith("litellm"):
+                    # Do not retry on litellm errors
+                    raise e
                 if self._is_context_length_exceeded(e):
                     self._handle_context_length()
                     continue
                 else:
                     self._handle_unknown_error(e)
                     raise e
             finally:
                 self.iterations += 1

+        # During the invoke loop, formatted_answer alternates between AgentAction
+        # (when the agent is using tools) and eventually becomes AgentFinish
+        # (when the agent reaches a final answer). This assertion confirms we've
+        # reached a final answer and helps type checking understand this transition.
+        assert isinstance(formatted_answer, AgentFinish)
         self._show_logs(formatted_answer)
         return formatted_answer

+    def _handle_unknown_error(self, exception: Exception) -> None:
+        """Handle unknown errors by informing the user."""
+        self._printer.print(
+            content="An unknown error occurred. Please check the details below.",
+            color="red",
+        )
+        self._printer.print(
+            content=f"Error details: {exception}",
+            color="red",
+        )
+
     def _has_reached_max_iterations(self) -> bool:
         """Check if the maximum number of iterations has been reached."""
         return self.iterations >= self.max_iter
@@ -160,10 +197,17 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):

     def _get_llm_response(self) -> str:
         """Call the LLM and return the response, handling any invalid responses."""
-        answer = self.llm.call(
-            self.messages,
-            callbacks=self.callbacks,
-        )
+        try:
+            answer = self.llm.call(
+                self.messages,
+                callbacks=self.callbacks,
+            )
+        except Exception as e:
+            self._printer.print(
+                content=f"Error during LLM call: {e}",
+                color="red",
+            )
+            raise e

         if not answer:
             self._printer.print(
@@ -184,7 +228,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error:
                 answer = answer.split("Observation:")[0].strip()

-        self.iterations += 1
         return self._format_answer(answer)

     def _handle_agent_action(
@@ -260,8 +303,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self._printer.print(
             content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
         )
+        description = (
+            getattr(self.task, "description") if self.task else "Not Found"
+        )
         self._printer.print(
-            content=f"\033[95m## Task:\033[00m \033[92m{self.task.description}\033[00m"
+            content=f"\033[95m## Task:\033[00m \033[92m{description}\033[00m"
         )

     def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
@@ -386,58 +432,50 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         )

     def _handle_crew_training_output(
-        self, result: AgentFinish, human_feedback: str | None = None
+        self, result: AgentFinish, human_feedback: Optional[str] = None
     ) -> None:
-        """Function to handle the process of the training data."""
+        """Handle the process of saving training data."""
         agent_id = str(self.agent.id)  # type: ignore
+        train_iteration = (
+            getattr(self.crew, "_train_iteration", None) if self.crew else None
+        )
+
+        if train_iteration is None or not isinstance(train_iteration, int):
+            self._printer.print(
+                content="Invalid or missing train iteration. Cannot save training data.",
+                color="red",
+            )
+            return

+        # Load training data
         training_handler = CrewTrainingHandler(TRAINING_DATA_FILE)
-        training_data = training_handler.load()
+        training_data = training_handler.load() or {}

-        # Check if training data exists, human input is not requested, and self.crew is valid
-        if training_data and not self.ask_for_human_input:
-            if self.crew is not None and hasattr(self.crew, "_train_iteration"):
-                train_iteration = self.crew._train_iteration
-                if agent_id in training_data and isinstance(train_iteration, int):
-                    training_data[agent_id][train_iteration][
-                        "improved_output"
-                    ] = result.output
-                    training_handler.save(training_data)
-                else:
-                    self._printer.print(
-                        content="Invalid train iteration type or agent_id not in training data.",
-                        color="red",
-                    )
-            else:
-                self._printer.print(
-                    content="Crew is None or does not have _train_iteration attribute.",
-                    color="red",
-                )
+        # Initialize or retrieve agent's training data
+        agent_training_data = training_data.get(agent_id, {})

-        if self.ask_for_human_input and human_feedback is not None:
-            training_data = {
+        if human_feedback is not None:
+            # Save initial output and human feedback
+            agent_training_data[train_iteration] = {
                 "initial_output": result.output,
                 "human_feedback": human_feedback,
                 "agent": agent_id,
                 "agent_role": self.agent.role,  # type: ignore
             }
-            if self.crew is not None and hasattr(self.crew, "_train_iteration"):
-                train_iteration = self.crew._train_iteration
-                if isinstance(train_iteration, int):
-                    CrewTrainingHandler(TRAINING_DATA_FILE).append(
-                        train_iteration, agent_id, training_data
-                    )
-                else:
-                    self._printer.print(
-                        content="Invalid train iteration type. Expected int.",
-                        color="red",
-                    )
-            else:
-                self._printer.print(
-                    content="Crew is None or does not have _train_iteration attribute.",
-                    color="red",
-                )
+        else:
+            # Save improved output
+            if train_iteration in agent_training_data:
+                agent_training_data[train_iteration]["improved_output"] = result.output
+            else:
+                self._printer.print(
+                    content=(
+                        f"No existing training data for agent {agent_id} and iteration "
+                        f"{train_iteration}. Cannot save improved output."
+                    ),
+                    color="red",
+                )
+                return
+
+        # Update the training data and save
+        training_data[agent_id] = agent_training_data
+        training_handler.save(training_data)

     def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str:
         prompt = prompt.replace("{input}", inputs["input"])
@@ -453,82 +491,111 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         return {"role": role, "content": prompt}

     def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
-        """
-        Handles the human feedback loop, allowing the user to provide feedback
-        on the agent's output and determining if additional iterations are needed.
+        """Handle human feedback with different flows for training vs regular use.

-        Parameters:
-            formatted_answer (AgentFinish): The initial output from the agent.
+        Args:
+            formatted_answer: The initial AgentFinish result to get feedback on

         Returns:
-            AgentFinish: The final output after incorporating human feedback.
+            AgentFinish: The final answer after processing feedback
         """
+        human_feedback = self._ask_human_input(formatted_answer.output)

+        if self._is_training_mode():
+            return self._handle_training_feedback(formatted_answer, human_feedback)
+
+        return self._handle_regular_feedback(formatted_answer, human_feedback)
+
+    def _is_training_mode(self) -> bool:
+        """Check if crew is in training mode."""
+        return bool(self.crew and self.crew._train)
+
+    def _handle_training_feedback(
+        self, initial_answer: AgentFinish, feedback: str
+    ) -> AgentFinish:
+        """Process feedback for training scenarios with single iteration."""
+        self._printer.print(
+            content="\nProcessing training feedback.\n",
+            color="yellow",
+        )
+        self._handle_crew_training_output(initial_answer, feedback)
+        self.messages.append(
+            self._format_msg(
+                self._i18n.slice("feedback_instructions").format(feedback=feedback)
+            )
+        )
+        improved_answer = self._invoke_loop()
+        self._handle_crew_training_output(improved_answer)
+        self.ask_for_human_input = False
+        return improved_answer
+
+    def _handle_regular_feedback(
+        self, current_answer: AgentFinish, initial_feedback: str
+    ) -> AgentFinish:
+        """Process feedback for regular use with potential multiple iterations."""
+        feedback = initial_feedback
+        answer = current_answer
+
         while self.ask_for_human_input:
-            human_feedback = self._ask_human_input(formatted_answer.output)
-
-            if self.crew and self.crew._train:
-                self._handle_crew_training_output(formatted_answer, human_feedback)
-
-            # Make an LLM call to verify if additional changes are requested based on human feedback
-            additional_changes_prompt = self._i18n.slice(
-                "human_feedback_classification"
-            ).format(feedback=human_feedback)
-
-            retry_count = 0
-            llm_call_successful = False
-            additional_changes_response = None
-
-            while retry_count < MAX_LLM_RETRY and not llm_call_successful:
-                try:
-                    additional_changes_response = (
-                        self.llm.call(
-                            [
-                                self._format_msg(
-                                    additional_changes_prompt, role="system"
-                                )
-                            ],
-                            callbacks=self.callbacks,
-                        )
-                        .strip()
-                        .lower()
-                    )
-                    llm_call_successful = True
-                except Exception as e:
-                    retry_count += 1
-
-                    self._printer.print(
-                        content=f"Error during LLM call to classify human feedback: {e}. Retrying... ({retry_count}/{MAX_LLM_RETRY})",
-                        color="red",
-                    )
-
-            if not llm_call_successful:
-                self._printer.print(
-                    content="Error processing feedback after multiple attempts.",
-                    color="red",
-                )
+            response = self._get_llm_feedback_response(feedback)
+
+            if not self._feedback_requires_changes(response):
                 self.ask_for_human_input = False
                 break

-            if additional_changes_response == "false":
-                self.ask_for_human_input = False
-            elif additional_changes_response == "true":
-                self.ask_for_human_input = True
-                # Add human feedback to messages
-                self.messages.append(self._format_msg(f"Feedback: {human_feedback}"))
-                # Invoke the loop again with updated messages
-                formatted_answer = self._invoke_loop()
-
-                if self.crew and self.crew._train:
-                    self._handle_crew_training_output(formatted_answer)
-            else:
-                # Unexpected response
-                self._printer.print(
-                    content=f"Unexpected response from LLM: '{additional_changes_response}'. Assuming no additional changes requested.",
-                    color="red",
-                )
-                self.ask_for_human_input = False
+            answer = self._process_feedback_iteration(feedback)
+            feedback = self._ask_human_input(answer.output)

-        return formatted_answer
+        return answer

+    def _get_llm_feedback_response(self, feedback: str) -> Optional[str]:
+        """Get LLM classification of whether feedback requires changes."""
+        prompt = self._i18n.slice("human_feedback_classification").format(
+            feedback=feedback
+        )
+        message = self._format_msg(prompt, role="system")
+
+        for retry in range(MAX_LLM_RETRY):
+            try:
+                response = self.llm.call([message], callbacks=self.callbacks)
+                return response.strip().lower() if response else None
+            except Exception as error:
+                self._log_feedback_error(retry, error)
+
+        self._log_max_retries_exceeded()
+        return None
+
+    def _feedback_requires_changes(self, response: Optional[str]) -> bool:
+        """Determine if feedback response indicates need for changes."""
+        return response == "true" if response else False
+
+    def _process_feedback_iteration(self, feedback: str) -> AgentFinish:
+        """Process a single feedback iteration."""
+        self.messages.append(
+            self._format_msg(
+                self._i18n.slice("feedback_instructions").format(feedback=feedback)
+            )
+        )
+        return self._invoke_loop()
+
+    def _log_feedback_error(self, retry_count: int, error: Exception) -> None:
+        """Log feedback processing errors."""
+        self._printer.print(
+            content=(
+                f"Error processing feedback: {error}. "
+                f"Retrying... ({retry_count + 1}/{MAX_LLM_RETRY})"
+            ),
+            color="red",
+        )
+
+    def _log_max_retries_exceeded(self) -> None:
+        """Log when max retries for feedback processing are exceeded."""
+        self._printer.print(
+            content=(
+                f"Failed to process feedback after {MAX_LLM_RETRY} attempts. "
+                "Ending feedback loop."
+            ),
+            color="red",
+        )

     def _handle_max_iterations_exceeded(self, formatted_answer):
         """
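A self-contained toy of the regular feedback loop above (not part of the diff): classify() stands in for the LLM call that answers "true" (changes requested) or "false" (accept), and the loop re-runs the agent until the user is satisfied.

    def classify(feedback: str) -> str:
        # Stand-in for the human_feedback_classification LLM call.
        return "false" if "looks good" in feedback.lower() else "true"

    def feedback_loop(initial_answer: str, feedbacks: list) -> str:
        answer = initial_answer
        for feedback in feedbacks:
            if classify(feedback) != "true":
                break
            answer = f"{answer} (revised per: {feedback})"
        return answer

    print(feedback_loop("draft", ["tighten the intro", "looks good"]))
    # draft (revised per: tighten the intro)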
@@ -350,7 +350,10 @@ def chat():
     Start a conversation with the Crew, collecting user-supplied inputs,
     and using the Chat LLM to generate responses.
     """
-    click.echo("Starting a conversation with the Crew")
+    click.secho(
+        "\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n",
+    )

     run_chat()
@@ -1,17 +1,52 @@
 import json
+import platform
 import re
 import sys
+import threading
+import time
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Set, Tuple

 import click
 import tomli
+from packaging import version

+from crewai.cli.utils import read_toml
+from crewai.cli.version import get_crewai_version
 from crewai.crew import Crew
 from crewai.llm import LLM
 from crewai.types.crew_chat import ChatInputField, ChatInputs
 from crewai.utilities.llm_utils import create_llm

+MIN_REQUIRED_VERSION = "0.98.0"
+
+
+def check_conversational_crews_version(
+    crewai_version: str, pyproject_data: dict
+) -> bool:
+    """
+    Check if the installed crewAI version supports conversational crews.
+
+    Args:
+        crewai_version: The current version of crewAI.
+        pyproject_data: Dictionary containing pyproject.toml data.
+
+    Returns:
+        bool: True if version check passes, False otherwise.
+    """
+    try:
+        if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION):
+            click.secho(
+                "You are using an older version of crewAI that doesn't support conversational crews. "
+                "Run 'uv upgrade crewai' to get the latest version.",
+                fg="red",
+            )
+            return False
+    except version.InvalidVersion:
+        click.secho("Invalid crewAI version format detected.", fg="red")
+        return False
+    return True


 def run_chat():
     """
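Why the gate above parses with packaging instead of comparing strings (not part of the diff):

    from packaging import version

    assert version.parse("0.100.1") > version.parse("0.98.0")  # numeric ordering
    assert "0.100.1" < "0.98.0"  # plain string comparison gets this wrong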
@@ -19,20 +54,47 @@ def run_chat():
     Incorporates crew_name, crew_description, and input fields to build a tool schema.
     Exits if crew_name or crew_description are missing.
     """
+    crewai_version = get_crewai_version()
+    pyproject_data = read_toml()
+
+    if not check_conversational_crews_version(crewai_version, pyproject_data):
+        return
+
     crew, crew_name = load_crew_and_name()
     chat_llm = initialize_chat_llm(crew)
     if not chat_llm:
         return

-    crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm)
-    crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs)
-    system_message = build_system_message(crew_chat_inputs)
-
-    # Call the LLM to generate the introductory message
-    introductory_message = chat_llm.call(
-        messages=[{"role": "system", "content": system_message}]
-    )
-    click.secho(f"\nAssistant: {introductory_message}\n", fg="green")
+    # Indicate that the crew is being analyzed
+    click.secho(
+        "\nAnalyzing crew and required inputs - this may take 3 to 30 seconds "
+        "depending on the complexity of your crew.",
+        fg="white",
+    )
+
+    # Start loading indicator
+    loading_complete = threading.Event()
+    loading_thread = threading.Thread(target=show_loading, args=(loading_complete,))
+    loading_thread.start()
+
+    try:
+        crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm)
+        crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs)
+        system_message = build_system_message(crew_chat_inputs)
+
+        # Call the LLM to generate the introductory message
+        introductory_message = chat_llm.call(
+            messages=[{"role": "system", "content": system_message}]
+        )
+    finally:
+        # Stop loading indicator
+        loading_complete.set()
+        loading_thread.join()
+
+    # Indicate that the analysis is complete
+    click.secho("\nFinished analyzing crew.\n", fg="white")
+
+    click.secho(f"Assistant: {introductory_message}\n", fg="green")

     messages = [
         {"role": "system", "content": system_message},
@@ -43,15 +105,17 @@ def run_chat():
         crew_chat_inputs.crew_name: create_tool_function(crew, messages),
     }

-    click.secho(
-        "\nEntering an interactive chat loop with function-calling.\n"
-        "Type 'exit' or Ctrl+C to quit.\n",
-        fg="cyan",
-    )
-
     chat_loop(chat_llm, messages, crew_tool_schema, available_functions)


+def show_loading(event: threading.Event):
+    """Display animated loading dots while processing."""
+    while not event.is_set():
+        print(".", end="", flush=True)
+        time.sleep(1)
+    print()
+
+
 def initialize_chat_llm(crew: Crew) -> Optional[LLM]:
     """Initializes the chat LLM and handles exceptions."""
     try:
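Driving show_loading from a caller, as run_chat() does above (not part of the diff; time.sleep stands in for the real crew-analysis LLM calls):

    import threading
    import time

    done = threading.Event()
    worker = threading.Thread(target=show_loading, args=(done,))
    worker.start()
    time.sleep(3)  # stand-in for the real work
    done.set()
    worker.join()  # one dot is printed per second while waiting, then a newline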
@@ -85,7 +149,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
         "Please keep your responses concise and friendly. "
         "If a user asks a question outside the crew's scope, provide a brief answer and remind them of the crew's purpose. "
         "After calling the tool, be prepared to take user feedback and make adjustments as needed. "
-        "If you are ever unsure about a user's request or need clarification, ask the user for more information."
+        "If you are ever unsure about a user's request or need clarification, ask the user for more information. "
         "Before doing anything else, introduce yourself with a friendly message like: 'Hey! I'm here to help you with [crew's purpose]. Could you please provide me with [inputs] so we can get started?' "
         "For example: 'Hey! I'm here to help you with uncovering and reporting cutting-edge developments through thorough research and detailed analysis. Could you please provide me with a topic you're interested in? This will help us generate a comprehensive research report and detailed analysis.'"
         f"\nCrew Name: {crew_chat_inputs.crew_name}"
@@ -102,25 +166,33 @@ def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
     return run_crew_tool_with_messages


+def flush_input():
+    """Flush any pending input from the user."""
+    if platform.system() == "Windows":
+        # Windows platform
+        import msvcrt
+
+        while msvcrt.kbhit():
+            msvcrt.getch()
+    else:
+        # Unix-like platforms (Linux, macOS)
+        import termios
+
+        termios.tcflush(sys.stdin, termios.TCIFLUSH)
+
+
 def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
     """Main chat loop for interacting with the user."""
     while True:
         try:
-            user_input = click.prompt("You", type=str)
-            if user_input.strip().lower() in ["exit", "quit"]:
-                click.echo("Exiting chat. Goodbye!")
-                break
+            # Flush any pending input before accepting new input
+            flush_input()

-            messages.append({"role": "user", "content": user_input})
-            final_response = chat_llm.call(
-                messages=messages,
-                tools=[crew_tool_schema],
-                available_functions=available_functions,
+            user_input = get_user_input()
+            handle_user_input(
+                user_input, chat_llm, messages, crew_tool_schema, available_functions
             )

-            messages.append({"role": "assistant", "content": final_response})
-            click.secho(f"\nAssistant: {final_response}\n", fg="green")
-
         except KeyboardInterrupt:
             click.echo("\nExiting chat. Goodbye!")
             break
@@ -129,6 +201,55 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
             break


+def get_user_input() -> str:
+    """Collect multi-line user input with exit handling."""
+    click.secho(
+        "\nYou (type your message below. Press 'Enter' twice when you're done):",
+        fg="blue",
+    )
+    user_input_lines = []
+    while True:
+        line = input()
+        if line.strip().lower() == "exit":
+            return "exit"
+        if line == "":
+            break
+        user_input_lines.append(line)
+    return "\n".join(user_input_lines)
+
+
+def handle_user_input(
+    user_input: str,
+    chat_llm: LLM,
+    messages: List[Dict[str, str]],
+    crew_tool_schema: Dict[str, Any],
+    available_functions: Dict[str, Any],
+) -> None:
+    if user_input.strip().lower() == "exit":
+        click.echo("Exiting chat. Goodbye!")
+        return
+
+    if not user_input.strip():
+        click.echo("Empty message. Please provide input or type 'exit' to quit.")
+        return
+
+    messages.append({"role": "user", "content": user_input})
+
+    # Indicate that assistant is processing
+    click.echo()
+    click.secho("Assistant is processing your input. Please wait...", fg="green")
+
+    # Process assistant's response
+    final_response = chat_llm.call(
+        messages=messages,
+        tools=[crew_tool_schema],
+        available_functions=available_functions,
+    )
+
+    messages.append({"role": "assistant", "content": final_response})
+    click.secho(f"\nAssistant: {final_response}\n", fg="green")
+
+
 def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
     """
     Dynamically build a LiteLLM 'function' schema for the given crew.
@@ -323,10 +444,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
         ):
             # Replace placeholders with input names
             task_description = placeholder_pattern.sub(
-                lambda m: m.group(1), task.description
+                lambda m: m.group(1), task.description or ""
             )
             expected_output = placeholder_pattern.sub(
-                lambda m: m.group(1), task.expected_output
+                lambda m: m.group(1), task.expected_output or ""
             )
             context_texts.append(f"Task Description: {task_description}")
             context_texts.append(f"Expected Output: {expected_output}")
@@ -337,10 +458,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
             or f"{{{input_name}}}" in agent.backstory
         ):
             # Replace placeholders with input names
-            agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role)
-            agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal)
+            agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
+            agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
             agent_backstory = placeholder_pattern.sub(
-                lambda m: m.group(1), agent.backstory
+                lambda m: m.group(1), agent.backstory or ""
             )
             context_texts.append(f"Agent Role: {agent_role}")
             context_texts.append(f"Agent Goal: {agent_goal}")
@@ -381,18 +502,20 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
     for task in crew.tasks:
         # Replace placeholders with input names
         task_description = placeholder_pattern.sub(
-            lambda m: m.group(1), task.description
+            lambda m: m.group(1), task.description or ""
         )
         expected_output = placeholder_pattern.sub(
-            lambda m: m.group(1), task.expected_output
+            lambda m: m.group(1), task.expected_output or ""
        )
         context_texts.append(f"Task Description: {task_description}")
         context_texts.append(f"Expected Output: {expected_output}")
     for agent in crew.agents:
         # Replace placeholders with input names
-        agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role)
-        agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal)
-        agent_backstory = placeholder_pattern.sub(lambda m: m.group(1), agent.backstory)
+        agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
+        agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
+        agent_backstory = placeholder_pattern.sub(
+            lambda m: m.group(1), agent.backstory or ""
+        )
         context_texts.append(f"Agent Role: {agent_role}")
         context_texts.append(f"Agent Goal: {agent_goal}")
         context_texts.append(f"Agent Backstory: {agent_backstory}")
@@ -2,11 +2,7 @@ import subprocess

 import click

-from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
-from crewai.memory.entity.entity_memory import EntityMemory
-from crewai.memory.long_term.long_term_memory import LongTermMemory
-from crewai.memory.short_term.short_term_memory import ShortTermMemory
-from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
+from crewai.cli.utils import get_crew


 def reset_memories_command(
@@ -30,30 +26,35 @@ def reset_memories_command(
     """

     try:
+        crew = get_crew()
+        if not crew:
+            raise ValueError("No crew found.")
         if all:
-            ShortTermMemory().reset()
-            EntityMemory().reset()
-            LongTermMemory().reset()
-            TaskOutputStorageHandler().reset()
-            KnowledgeStorage().reset()
+            crew.reset_memories(command_type="all")
             click.echo("All memories have been reset.")
-        else:
-            if long:
-                LongTermMemory().reset()
-                click.echo("Long term memory has been reset.")
+            return

-            if short:
-                ShortTermMemory().reset()
-                click.echo("Short term memory has been reset.")
-            if entity:
-                EntityMemory().reset()
-                click.echo("Entity memory has been reset.")
-            if kickoff_outputs:
-                TaskOutputStorageHandler().reset()
-                click.echo("Latest Kickoff outputs stored has been reset.")
-            if knowledge:
-                KnowledgeStorage().reset()
-                click.echo("Knowledge has been reset.")
+        if not any([long, short, entity, kickoff_outputs, knowledge]):
+            click.echo(
+                "No memory type specified. Please specify at least one type to reset."
+            )
+            return
+
+        if long:
+            crew.reset_memories(command_type="long")
+            click.echo("Long term memory has been reset.")
+        if short:
+            crew.reset_memories(command_type="short")
+            click.echo("Short term memory has been reset.")
+        if entity:
+            crew.reset_memories(command_type="entity")
+            click.echo("Entity memory has been reset.")
+        if kickoff_outputs:
+            crew.reset_memories(command_type="kickoff_outputs")
+            click.echo("Latest kickoff outputs stored have been reset.")
+        if knowledge:
+            crew.reset_memories(command_type="knowledge")
+            click.echo("Knowledge has been reset.")

     except subprocess.CalledProcessError as e:
         click.echo(f"An error occurred while resetting the memories: {e}", err=True)
src/crewai/cli/templates/crew/.gitignore (vendored):
@@ -1,2 +1,3 @@
 .env
 __pycache__/
+.DS_Store
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
 authors = [{ name = "Your Name", email = "you@example.com" }]
 requires-python = ">=3.10,<3.13"
 dependencies = [
-    "crewai[tools]>=0.98.0,<1.0.0"
+    "crewai[tools]>=0.100.1,<1.0.0"
 ]

 [project.scripts]
src/crewai/cli/templates/flow/.gitignore (vendored):
@@ -1,3 +1,4 @@
 .env
 __pycache__/
 lib/
+.DS_Store
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
 authors = [{ name = "Your Name", email = "you@example.com" }]
 requires-python = ">=3.10,<3.13"
 dependencies = [
-    "crewai[tools]>=0.98.0,<1.0.0",
+    "crewai[tools]>=0.100.1,<1.0.0",
 ]

 [project.scripts]
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
 readme = "README.md"
 requires-python = ">=3.10,<3.13"
 dependencies = [
-    "crewai[tools]>=0.98.0"
+    "crewai[tools]>=0.100.1"
 ]

 [tool.crewai]
@@ -9,6 +9,7 @@ import tomli
 from rich.console import Console

 from crewai.cli.constants import ENV_VARS
+from crewai.crew import Crew

 if sys.version_info >= (3, 11):
     import tomllib
@@ -247,3 +248,64 @@ def write_env_file(folder_path, env_vars):
     with open(env_file_path, "w") as file:
         for key, value in env_vars.items():
             file.write(f"{key}={value}\n")
+
+
+def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
+    """Get the crew instance from the crew.py file."""
+    try:
+        import importlib.util
+        import os
+
+        for root, _, files in os.walk("."):
+            if "crew.py" in files:
+                crew_path = os.path.join(root, "crew.py")
+                try:
+                    spec = importlib.util.spec_from_file_location(
+                        "crew_module", crew_path
+                    )
+                    if not spec or not spec.loader:
+                        continue
+                    module = importlib.util.module_from_spec(spec)
+                    try:
+                        sys.modules[spec.name] = module
+                        spec.loader.exec_module(module)
+
+                        for attr_name in dir(module):
+                            attr = getattr(module, attr_name)
+                            try:
+                                if callable(attr) and hasattr(attr, "crew"):
+                                    crew_instance = attr().crew()
+                                    return crew_instance
+                            except Exception as e:
+                                print(f"Error processing attribute {attr_name}: {e}")
+                                continue
+
+                    except Exception as exec_error:
+                        print(f"Error executing module: {exec_error}")
+                        import traceback
+
+                        print(f"Traceback: {traceback.format_exc()}")
+
+                except (ImportError, AttributeError) as e:
+                    if require:
+                        console.print(
+                            f"Error importing crew from {crew_path}: {str(e)}",
+                            style="bold red",
+                        )
+                    continue
+
+                break
+
+        if require:
+            console.print("No valid Crew instance found in crew.py", style="bold red")
+            raise SystemExit
+        return None
+
+    except Exception as e:
+        if require:
+            console.print(
+                f"Unexpected error while loading crew: {str(e)}", style="bold red"
+            )
+            raise SystemExit
+        return None
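Usage sketch for the new helper above (not part of the diff): get_crew walks the current directory for a crew.py and returns the first Crew it can instantiate, or None.

    from crewai.cli.utils import get_crew

    crew = get_crew(require=False)
    if crew is None:
        print("No crew.py with a valid Crew found under the current directory.")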
@@ -4,6 +4,7 @@ import re
 import uuid
 import warnings
 from concurrent.futures import Future
+from copy import copy as shallow_copy
 from hashlib import md5
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -37,7 +38,6 @@ from crewai.tasks.task_output import TaskOutput
 from crewai.telemetry import Telemetry
 from crewai.tools.agent_tools.agent_tools import AgentTools
-from crewai.tools.base_tool import Tool
 from crewai.types.crew_chat import ChatInputs
 from crewai.types.usage_metrics import UsageMetrics
 from crewai.utilities import I18N, FileHandler, Logger, RPMController
 from crewai.utilities.constants import TRAINING_DATA_FILE
@@ -84,6 +84,7 @@ class Crew(BaseModel):
         step_callback: Callback to be executed after each step for every agent's execution.
         share_crew: Whether you want to share the complete crew information and execution with crewAI to make the library better, and allow us to train models.
         planning: Plan the crew execution and add the plan to the crew.
+        chat_llm: The language model used for orchestrating chat interactions with the crew.
     """

     __hash__ = object.__hash__  # type: ignore
@@ -182,9 +183,9 @@ class Crew(BaseModel):
         default=None,
         description="Path to the prompt json file to be used for the crew.",
     )
-    output_log_file: Optional[str] = Field(
+    output_log_file: Optional[Union[bool, str]] = Field(
         default=None,
-        description="output_log_file",
+        description="Path to the log file to be saved",
     )
     planning: Optional[bool] = Field(
         default=False,
@@ -210,8 +211,9 @@ class Crew(BaseModel):
         default=None,
         description="LLM used to handle chatting with the crew.",
     )
-    _knowledge: Optional[Knowledge] = PrivateAttr(
+    knowledge: Optional[Knowledge] = Field(
         default=None,
+        description="Knowledge for the crew.",
     )

     @field_validator("id", mode="before")
@@ -289,9 +291,9 @@ class Crew(BaseModel):
             if isinstance(self.knowledge_sources, list) and all(
                 isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
             ):
-                self._knowledge = Knowledge(
+                self.knowledge = Knowledge(
                     sources=self.knowledge_sources,
-                    embedder_config=self.embedder,
+                    embedder=self.embedder,
                     collection_name="crew",
                 )
@@ -492,21 +494,26 @@ class Crew(BaseModel):
         train_crew = self.copy()
         train_crew._setup_for_training(filename)

-        for n_iteration in range(n_iterations):
-            train_crew._train_iteration = n_iteration
-            train_crew.kickoff(inputs=inputs)
-
-        training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
-
-        for agent in train_crew.agents:
-            if training_data.get(str(agent.id)):
-                result = TaskEvaluator(agent).evaluate_training_data(
-                    training_data=training_data, agent_id=str(agent.id)
-                )
-
-                CrewTrainingHandler(filename).save_trained_data(
-                    agent_id=str(agent.role), trained_data=result.model_dump()
-                )
+        try:
+            for n_iteration in range(n_iterations):
+                train_crew._train_iteration = n_iteration
+                train_crew.kickoff(inputs=inputs)
+
+            training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
+
+            for agent in train_crew.agents:
+                if training_data.get(str(agent.id)):
+                    result = TaskEvaluator(agent).evaluate_training_data(
+                        training_data=training_data, agent_id=str(agent.id)
+                    )
+                    CrewTrainingHandler(filename).save_trained_data(
+                        agent_id=str(agent.role), trained_data=result.model_dump()
+                    )
+        except Exception as e:
+            self._logger.log("error", f"Training failed: {e}", color="red")
+            CrewTrainingHandler(TRAINING_DATA_FILE).clear()
+            CrewTrainingHandler(filename).clear()
+            raise

     def kickoff(
         self,
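Sketch of the now-guarded training entry point above (not part of the diff; the filename and inputs are illustrative, and any Crew.train() parameters beyond these are assumptions). On failure both training-data files are cleared before the exception propagates:

    crew.train(
        n_iterations=2,
        filename="trained_agents_data.pkl",
        inputs={"topic": "AI"},
    )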
@@ -674,12 +681,7 @@ class Crew(BaseModel):
                 manager.tools = []
                 raise Exception("Manager agent should not have tools")
         else:
-            self.manager_llm = (
-                getattr(self.manager_llm, "model_name", None)
-                or getattr(self.manager_llm, "model", None)
-                or getattr(self.manager_llm, "deployment_name", None)
-                or self.manager_llm
-            )
+            self.manager_llm = create_llm(self.manager_llm)
             manager = Agent(
                 role=i18n.retrieve("hierarchical_manager_agent", "role"),
                 goal=i18n.retrieve("hierarchical_manager_agent", "goal"),
@@ -1011,8 +1013,8 @@ class Crew(BaseModel):
         return result

     def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]:
-        if self._knowledge:
-            return self._knowledge.query(query)
+        if self.knowledge:
+            return self.knowledge.query(query)
         return None

     def fetch_inputs(self) -> Set[str]:
@@ -1056,6 +1058,8 @@ class Crew(BaseModel):
             "_telemetry",
             "agents",
             "tasks",
+            "knowledge_sources",
+            "knowledge",
         }

         cloned_agents = [agent.copy() for agent in self.agents]
@@ -1063,6 +1067,9 @@ class Crew(BaseModel):
         task_mapping = {}

         cloned_tasks = []
+        existing_knowledge_sources = shallow_copy(self.knowledge_sources)
+        existing_knowledge = shallow_copy(self.knowledge)
+
         for task in self.tasks:
             cloned_task = task.copy(cloned_agents, task_mapping)
             cloned_tasks.append(cloned_task)
@@ -1082,7 +1089,13 @@ class Crew(BaseModel):
         copied_data.pop("agents", None)
         copied_data.pop("tasks", None)

-        copied_crew = Crew(**copied_data, agents=cloned_agents, tasks=cloned_tasks)
+        copied_crew = Crew(
+            **copied_data,
+            agents=cloned_agents,
+            tasks=cloned_tasks,
+            knowledge_sources=existing_knowledge_sources,
+            knowledge=existing_knowledge,
+        )

         return copied_crew
@@ -1154,3 +1167,80 @@ class Crew(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
|
||||
|
||||
def reset_memories(self, command_type: str) -> None:
|
||||
"""Reset specific or all memories for the crew.
|
||||
|
||||
Args:
|
||||
command_type: Type of memory to reset.
|
||||
Valid options: 'long', 'short', 'entity', 'knowledge',
|
||||
'kickoff_outputs', or 'all'
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid command type is provided.
|
||||
RuntimeError: If memory reset operation fails.
|
||||
"""
|
||||
VALID_TYPES = frozenset(
|
||||
["long", "short", "entity", "knowledge", "kickoff_outputs", "all"]
|
||||
)
|
||||
|
||||
if command_type not in VALID_TYPES:
|
||||
raise ValueError(
|
||||
f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}"
|
||||
)
|
||||
|
||||
try:
|
||||
if command_type == "all":
|
||||
self._reset_all_memories()
|
||||
else:
|
||||
self._reset_specific_memory(command_type)
|
||||
|
||||
self._logger.log("info", f"{command_type} memory has been reset")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to reset {command_type} memory: {str(e)}"
|
||||
self._logger.log("error", error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
def _reset_all_memories(self) -> None:
|
||||
"""Reset all available memory systems."""
|
||||
memory_systems = [
|
||||
("short term", self._short_term_memory),
|
||||
("entity", self._entity_memory),
|
||||
("long term", self._long_term_memory),
|
||||
("task output", self._task_output_handler),
|
||||
("knowledge", self.knowledge),
|
||||
]
|
||||
|
||||
for name, system in memory_systems:
|
||||
if system is not None:
|
||||
try:
|
||||
system.reset()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||
|
||||
def _reset_specific_memory(self, memory_type: str) -> None:
|
||||
"""Reset a specific memory system.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory to reset
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the specified memory system fails to reset
|
||||
"""
|
||||
reset_functions = {
|
||||
"long": (self._long_term_memory, "long term"),
|
||||
"short": (self._short_term_memory, "short term"),
|
||||
"entity": (self._entity_memory, "entity"),
|
||||
"knowledge": (self.knowledge, "knowledge"),
|
||||
"kickoff_outputs": (self._task_output_handler, "task output"),
|
||||
}
|
||||
|
||||
memory_system, name = reset_functions[memory_type]
|
||||
if memory_system is None:
|
||||
raise RuntimeError(f"{name} memory system is not initialized")
|
||||
|
||||
try:
|
||||
memory_system.reset()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||
|
||||
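To make the new reset API concrete, here is a minimal, hedged usage sketch; the agent, task, and crew shown are illustrative stand-ins, not part of this diff:

```python
from crewai import Agent, Crew, Task

# Hypothetical crew; only reset_memories() below comes from this diff.
researcher = Agent(role="Researcher", goal="Find facts", backstory="...")
summary = Task(description="Summarize a topic", expected_output="A short summary", agent=researcher)
crew = Crew(agents=[researcher], tasks=[summary], memory=True)

crew.reset_memories("short")  # 'long', 'short', 'entity', 'knowledge', 'kickoff_outputs', or 'all'
crew.reset_memories("all")

try:
    crew.reset_memories("bogus")  # invalid types now fail fast
except ValueError as err:
    print(err)  # the message lists the valid options
```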
@@ -447,14 +447,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
def __init__(
self,
persistence: Optional[FlowPersistence] = None,
restore_uuid: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.

Args:
persistence: Optional persistence backend for storing flow states
restore_uuid: Optional UUID to restore state from persistence
**kwargs: Additional state values to initialize or override
"""
# Initialize basic instance attributes
@@ -464,64 +462,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._method_outputs: List[Any] = []  # List to store all method outputs
self._persistence: Optional[FlowPersistence] = persistence

# Validate state model before initialization
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, BaseModel) and not issubclass(
self.initial_state, FlowState
):
# Check if model has id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
# Initialize state with initial values
self._state = self._create_initial_state()

# Handle persistence and potential ID conflicts
stored_state = None
if self._persistence is not None:
if (
restore_uuid
and kwargs
and "id" in kwargs
and restore_uuid != kwargs["id"]
):
raise ValueError(
f"Conflicting IDs provided: restore_uuid='{restore_uuid}' "
f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration."
)

# Attempt to load state, prioritizing restore_uuid
if restore_uuid:
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="bold_yellow")
stored_state = self._persistence.load_state(restore_uuid)
if not stored_state:
raise ValueError(
f"No state found for restore_uuid='{restore_uuid}'"
)
elif kwargs and "id" in kwargs:
self._log_flow_event(f"Loading flow state from memory for ID: {kwargs['id']}", color="bold_yellow")
stored_state = self._persistence.load_state(kwargs["id"])
if not stored_state:
# For kwargs["id"], we allow creating new state if not found
self._state = self._create_initial_state()
if kwargs:
self._initialize_state(kwargs)
return

# Initialize state based on persistence and kwargs
if stored_state:
# Create initial state and restore from persistence
self._state = self._create_initial_state()
self._restore_state(stored_state)
# Apply any additional kwargs to override specific fields
if kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "id"}
if filtered_kwargs:
self._initialize_state(filtered_kwargs)
else:
# No stored state, create new state with initial values
self._state = self._create_initial_state()
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)

self._telemetry.flow_creation_span(self.__class__.__name__)
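A hedged sketch of the two restoration paths this constructor now supports; the flow class, state model, and UUID value are illustrative, and the persistence import path is an assumption:

```python
from crewai.flow.flow import Flow, FlowState, start
from crewai.flow.persistence import SQLiteFlowPersistence  # assumed export path

class CounterState(FlowState):
    counter: int = 0

class CounterFlow(Flow[CounterState]):
    @start()
    def begin(self):
        self.state.counter += 1

persistence = SQLiteFlowPersistence()

# Path 1: restore_uuid must already exist in storage, otherwise ValueError.
restored = CounterFlow(persistence=persistence, restore_uuid="9b2c0b8e-0000-0000-0000-000000000000")

# Path 2: kwargs["id"] is softer; a fresh state is created when nothing is stored.
maybe_new = CounterFlow(persistence=persistence, id="9b2c0b8e-0000-0000-0000-000000000000", counter=10)

# Passing both restore_uuid and a different kwargs["id"] raises ValueError.
```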
@@ -635,18 +581,18 @@ class Flow(Generic[T], metaclass=FlowMeta):
@property
def flow_id(self) -> str:
"""Returns the unique identifier of this flow instance.

This property provides a consistent way to access the flow's unique identifier
regardless of the underlying state implementation (dict or BaseModel).

Returns:
str: The flow's unique identifier, or an empty string if not found

Note:
This property safely handles both dictionary and BaseModel state types,
returning an empty string if the ID cannot be retrieved rather than raising
an exception.

Example:
```python
flow = MyFlow()
@@ -654,9 +600,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
```
"""
try:
if not hasattr(self, '_state'):
if not hasattr(self, "_state"):
return ""

if isinstance(self._state, dict):
return str(self._state.get("id", ""))
elif isinstance(self._state, BaseModel):
@@ -731,7 +677,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
"""
# When restoring from persistence, use the stored ID
stored_id = stored_state.get("id")
self._log_flow_event(f"Restoring flow state from memory for ID: {stored_id}", color="bold_yellow")
if not stored_id:
raise ValueError("Stored state must have an 'id' field")
@@ -755,6 +700,41 @@ class Flow(Generic[T], metaclass=FlowMeta):
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")

def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""Start the flow execution.

Args:
inputs: Optional dictionary containing input values and potentially a state ID to restore
"""
# Handle state restoration if ID is provided in inputs
if inputs and "id" in inputs and self._persistence is not None:
restore_uuid = inputs["id"]
stored_state = self._persistence.load_state(restore_uuid)

# Override the id in the state if it exists in inputs
if "id" in inputs:
if isinstance(self._state, dict):
self._state["id"] = inputs["id"]
elif isinstance(self._state, BaseModel):
setattr(self._state, "id", inputs["id"])

if stored_state:
self._log_flow_event(
f"Loading flow state from memory for UUID: {restore_uuid}",
color="yellow",
)
# Restore the state
self._restore_state(stored_state)
else:
self._log_flow_event(
f"No flow state found for UUID: {restore_uuid}", color="red"
)

# Apply any additional inputs after restoration
filtered_inputs = {k: v for k, v in inputs.items() if k != "id"}
if filtered_inputs:
self._initialize_state(filtered_inputs)

# Start flow execution
self.event_emitter.send(
self,
event=FlowStartedEvent(
@@ -762,10 +742,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
flow_name=self.__class__.__name__,
),
)
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="yellow")
self._log_flow_event(
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
)

if inputs is not None:
if inputs is not None and "id" not in inputs:
self._initialize_state(inputs)

return asyncio.run(self.kickoff_async())

async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
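And a hedged sketch of resuming a run through kickoff(), which this hunk wires up; it reuses the hypothetical CounterFlow from the sketch above:

```python
flow = CounterFlow(persistence=SQLiteFlowPersistence())
flow.kickoff()            # first run; state is persisted under a fresh ID
saved_id = flow.flow_id

resumed = CounterFlow(persistence=SQLiteFlowPersistence())
# Restores the stored state for that ID, then applies the remaining inputs.
resumed.kickoff(inputs={"id": saved_id, "counter": 99})
```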
@@ -1008,20 +991,22 @@ class Flow(Generic[T], metaclass=FlowMeta):

traceback.print_exc()

def _log_flow_event(self, message: str, color: str = "yellow", level: str = "info") -> None:
def _log_flow_event(
self, message: str, color: str = "yellow", level: str = "info"
) -> None:
"""Centralized logging method for flow events.

This method provides a consistent interface for logging flow-related events,
combining both console output with colors and proper logging levels.

Args:
message: The message to log
color: Color to use for console output (default: yellow)
Available colors: purple, red, bold_green, bold_purple,
bold_blue, yellow, bold_yellow
bold_blue, yellow
level: Log level to use (default: info)
Supported levels: info, warning

Note:
This method uses the Printer utility for colored console output
and the standard logging module for log level support.
@@ -1031,7 +1016,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
logger.info(message)
elif level == "warning":
logger.warning(message)

def plot(self, filename: str = "crewai_flow") -> None:
self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys())
@@ -54,57 +54,44 @@ LOG_MESSAGES = {

class PersistenceDecorator:
"""Class to handle flow state persistence with consistent logging."""

_printer = Printer()  # Class-level printer instance

@classmethod
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None:
"""Persist flow state with proper error handling and logging.

This method handles the persistence of flow state data, including proper
error handling and colored console output for status updates.

Args:
flow_instance: The flow instance whose state to persist
method_name: Name of the method that triggered persistence
persistence_instance: The persistence backend to use

Raises:
ValueError: If flow has no state or state lacks an ID
RuntimeError: If state persistence fails
AttributeError: If flow instance lacks required state attributes

Note:
Uses bold_yellow color for success messages and red for errors.
All operations are logged at appropriate levels (info/error).

Example:
```python
@persist
def my_flow_method(self):
# Method implementation
pass
# State will be automatically persisted after method execution
```
"""
try:
state = getattr(flow_instance, 'state', None)
if state is None:
raise ValueError("Flow instance has no state")

flow_uuid: Optional[str] = None
if isinstance(state, dict):
flow_uuid = state.get('id')
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None)

if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence")

# Log state saving with consistent message
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="bold_yellow")
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))

try:
persistence_instance.save_state(
flow_uuid=flow_uuid,
@@ -154,44 +141,79 @@ def persist(persistence: Optional[FlowPersistence] = None):
def begin(self):
pass
"""

def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()

if isinstance(target, type):
# Class decoration
class_methods = {}
for name, method in target.__dict__.items():
if callable(method) and hasattr(method, "__is_flow_method__"):
# Wrap each flow method with persistence
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def class_async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
method_coro = method(self, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
return result
class_methods[name] = class_async_wrapper
else:
@functools.wraps(method)
def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
return result
class_methods[name] = class_sync_wrapper
original_init = getattr(target, "__init__")

# Preserve flow-specific attributes
@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
if 'persistence' not in kwargs:
kwargs['persistence'] = actual_persistence
original_init(self, *args, **kwargs)

setattr(target, "__init__", new_init)

# Store original methods to preserve their decorators
original_methods = {}

for name, method in target.__dict__.items():
if callable(method) and (
hasattr(method, "__is_start_method__") or
hasattr(method, "__trigger_methods__") or
hasattr(method, "__condition_type__") or
hasattr(method, "__is_flow_method__") or
hasattr(method, "__is_router__")
):
original_methods[name] = method

# Create wrapped versions of the methods that include persistence
for name, method in original_methods.items():
if asyncio.iscoroutinefunction(method):
# Create a closure to capture the current name and method
def create_async_wrapper(method_name: str, original_method: Callable):
@functools.wraps(original_method)
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = await original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
return result
return method_wrapper

wrapped = create_async_wrapper(name, method)

# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(class_methods[name], attr, getattr(method, attr))
setattr(class_methods[name], "__is_flow_method__", True)
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)

# Update the class with the wrapped method
setattr(target, name, wrapped)
else:
# Create a closure to capture the current name and method
def create_sync_wrapper(method_name: str, original_method: Callable):
@functools.wraps(original_method)
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
return result
return method_wrapper

wrapped = create_sync_wrapper(name, method)

# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)

# Update the class with the wrapped method
setattr(target, name, wrapped)

# Update class with wrapped methods
for name, method in class_methods.items():
setattr(target, name, method)
return target
else:
# Method decoration
@@ -208,6 +230,7 @@ def persist(persistence: Optional[FlowPersistence] = None):
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
return result

for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
@@ -219,6 +242,7 @@ def persist(persistence: Optional[FlowPersistence] = None):
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
return result

for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
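A hedged sketch of the class-level decoration this hunk adds: every flow method is wrapped so state is saved after each step. Flow and state names here are illustrative:

```python
from crewai.flow.flow import Flow, FlowState, listen, start
from crewai.flow.persistence import persist  # assumed export path

class TallyState(FlowState):
    total: int = 0

@persist()  # no argument: falls back to SQLiteFlowPersistence
class TallyFlow(Flow[TallyState]):
    @start()
    def begin(self):
        self.state.total = 1

    @listen(begin)
    def bump(self, _):
        self.state.total += 1

TallyFlow().kickoff()  # state persisted after begin() and again after bump()
```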
@@ -3,10 +3,9 @@ SQLite-based implementation of flow state persistence.
"""

import json
import os
import sqlite3
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Union

from pydantic import BaseModel
@@ -16,34 +15,34 @@ from crewai.flow.persistence.base import FlowPersistence

class SQLiteFlowPersistence(FlowPersistence):
"""SQLite-based implementation of flow state persistence.

This class provides a simple, file-based persistence implementation using SQLite.
It's suitable for development and testing, or for production use cases with
moderate performance requirements.
"""

db_path: str  # Type annotation for instance variable

def __init__(self, db_path: Optional[str] = None):
"""Initialize SQLite persistence.

Args:
db_path: Path to the SQLite database file. If not provided, uses
db_storage_path() from utilities.paths.

Raises:
ValueError: If db_path is invalid
"""
from crewai.utilities.paths import db_storage_path
# Get path from argument or default location
path = db_path or db_storage_path()

path = db_path or str(Path(db_storage_path()) / "flow_states.db")

if not path:
raise ValueError("Database path must be provided")

self.db_path = path  # Now mypy knows this is str
self.init_db()

def init_db(self) -> None:
"""Create the necessary tables if they don't exist."""
with sqlite3.connect(self.db_path) as conn:
@@ -58,10 +57,10 @@ class SQLiteFlowPersistence(FlowPersistence):
""")
# Add index for faster UUID lookups
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
ON flow_states(flow_uuid)
""")

def save_state(
self,
flow_uuid: str,
@@ -69,7 +68,7 @@ class SQLiteFlowPersistence(FlowPersistence):
state_data: Union[Dict[str, Any], BaseModel],
) -> None:
"""Save the current flow state to SQLite.

Args:
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
@@ -84,7 +83,7 @@ class SQLiteFlowPersistence(FlowPersistence):
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)

with sqlite3.connect(self.db_path) as conn:
conn.execute("""
INSERT INTO flow_states (
@@ -99,13 +98,13 @@ class SQLiteFlowPersistence(FlowPersistence):
datetime.utcnow().isoformat(),
json.dumps(state_dict),
))

def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID.

Args:
flow_uuid: Unique identifier for the flow instance

Returns:
The most recent state as a dictionary, or None if no state exists
"""
@@ -118,7 +117,7 @@ class SQLiteFlowPersistence(FlowPersistence):
LIMIT 1
""", (flow_uuid,))
row = cursor.fetchone()

if row:
return json.loads(row[0])
return None
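The storage API itself, in a hedged sketch; the path and payload are placeholders:

```python
from crewai.flow.persistence import SQLiteFlowPersistence  # assumed export path

store = SQLiteFlowPersistence("/tmp/flow_states.db")  # default now lands in db_storage_path()/flow_states.db
store.save_state(
    flow_uuid="abc-123",
    method_name="begin",
    state_data={"id": "abc-123", "counter": 1},
)
print(store.load_state("abc-123"))  # most recent state dict for that UUID
print(store.load_state("missing"))  # None when nothing was stored
```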
@@ -15,20 +15,20 @@ class Knowledge(BaseModel):
Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder_config: Optional[Dict[str, Any]] = None
embedder: Optional[Dict[str, Any]] = None
"""

sources: List[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder_config: Optional[Dict[str, Any]] = None
embedder: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None

def __init__(
self,
collection_name: str,
sources: List[BaseKnowledgeSource],
embedder_config: Optional[Dict[str, Any]] = None,
embedder: Optional[Dict[str, Any]] = None,
storage: Optional[KnowledgeStorage] = None,
**data,
):
@@ -37,25 +37,23 @@ class Knowledge(BaseModel):
self.storage = storage
else:
self.storage = KnowledgeStorage(
embedder_config=embedder_config, collection_name=collection_name
embedder=embedder, collection_name=collection_name
)
self.sources = sources
self.storage.initialize_knowledge_storage()
for source in sources:
source.storage = self.storage
source.add()
self._add_sources()

def query(self, query: List[str], limit: int = 3) -> List[Dict[str, Any]]:
"""
Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.

Raises:
ValueError: If storage is not initialized.
"""
if self.storage is None:
raise ValueError("Storage is not initialized.")

results = self.storage.search(
query,
limit,
@@ -63,6 +61,15 @@ class Knowledge(BaseModel):
return results

def _add_sources(self):
for source in self.sources:
source.storage = self.storage
source.add()
try:
for source in self.sources:
source.storage = self.storage
source.add()
except Exception as e:
raise e

def reset(self) -> None:
if self.storage:
self.storage.reset()
else:
raise ValueError("Storage is not initialized.")
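A hedged sketch of building Knowledge with the renamed `embedder` argument; the source text, collection name, and embedder config are illustrative:

```python
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

source = StringKnowledgeSource(content="Paris is the capital of France.")
kb = Knowledge(
    collection_name="demo",
    sources=[source],
    embedder={"provider": "openai", "config": {"model": "text-embedding-3-small"}},
)
print(kb.query(["capital of France"], limit=1))  # top chunks across all sources
```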
@@ -29,7 +29,13 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def validate_file_path(cls, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if v is None and info.data.get("file_path" if info.field_name == "file_paths" else "file_paths") is None:
if (
v is None
and info.data.get(
"file_path" if info.field_name == "file_paths" else "file_paths"
)
is None
):
raise ValueError("Either file_path or file_paths must be provided")
return v
@@ -48,11 +48,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):

def __init__(
self,
embedder_config: Optional[Dict[str, Any]] = None,
embedder: Optional[Dict[str, Any]] = None,
collection_name: Optional[str] = None,
):
self.collection_name = collection_name
self._set_embedder_config(embedder_config)
self._set_embedder_config(embedder)

def search(
self,
@@ -99,7 +99,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
)
if self.app:
self.collection = self.app.get_or_create_collection(
name=collection_name, embedding_function=self.embedder_config
name=collection_name, embedding_function=self.embedder
)
else:
raise Exception("Vector Database Client not initialized")
@@ -187,17 +187,15 @@ class KnowledgeStorage(BaseKnowledgeStorage):
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)

def _set_embedder_config(
self, embedder_config: Optional[Dict[str, Any]] = None
) -> None:
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
"""Set the embedding configuration for the knowledge storage.

Args:
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
If None or empty, defaults to the default embedding function.
"""
self.embedder_config = (
EmbeddingConfigurator().configure_embedder(embedder_config)
if embedder_config
self.embedder = (
EmbeddingConfigurator().configure_embedder(embedder)
if embedder
else self._create_default_embedding_function()
)
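Direct use of the storage follows the same rename; a hedged sketch, where the config values are examples and omitting `embedder` falls back to the default OpenAI embedding function:

```python
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage

storage = KnowledgeStorage(
    embedder={"provider": "openai", "config": {"model": "text-embedding-3-small"}},
    collection_name="demo",
)
storage.initialize_knowledge_storage()
```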
@@ -5,15 +5,17 @@ import sys
import threading
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast

from dotenv import load_dotenv
from pydantic import BaseModel

with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
import litellm
from litellm import Choices, get_supported_openai_params
from litellm.types.utils import ModelResponse
from litellm.utils import supports_response_schema

from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -128,21 +130,23 @@ class LLM:
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Dict[str, Any]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
**kwargs,
):
self.model = model
self.timeout = timeout
self.temperature = temperature
self.top_p = top_p
self.n = n
self.stop = stop
self.max_completion_tokens = max_completion_tokens
self.max_tokens = max_tokens
self.presence_penalty = presence_penalty
@@ -153,44 +157,83 @@ class LLM:
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.base_url = base_url
self.api_base = api_base
self.api_version = api_version
self.api_key = api_key
self.callbacks = callbacks
self.context_window_size = 0
self.reasoning_effort = reasoning_effort
self.additional_params = kwargs

litellm.drop_params = True

# Normalize self.stop to always be a List[str]
if stop is None:
self.stop: List[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = stop

self.set_callbacks(callbacks)
self.set_env_callbacks()

def call(
self,
messages: List[Dict[str, str]],
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> str:
"""
High-level call method that:
1) Calls litellm.completion
2) Checks for function/tool calls
3) If a tool call is found:
a) executes the function
b) returns the result
4) If no tool call, returns the text response
High-level llm call method that:
1) Accepts either a string or a list of messages
2) Converts string input to the required message format
3) Calls litellm.completion
4) Handles function/tool calls if any
5) Returns the final text response or tool result

:param messages: The conversation messages
:param tools: Optional list of function schemas for function calling
:param callbacks: Optional list of callbacks
:param available_functions: A dictionary mapping function_name -> actual Python function
:return: Final text response from the LLM or the tool result
Parameters:
- messages (Union[str, List[Dict[str, str]]]): The input messages for the LLM.
- If a string is provided, it will be converted into a message list with a single entry.
- If a list of dictionaries is provided, each dictionary should have 'role' and 'content' keys.
- tools (Optional[List[dict]]): A list of tool schemas for function calling.
- callbacks (Optional[List[Any]]): A list of callback functions to be executed.
- available_functions (Optional[Dict[str, Any]]): A dictionary mapping function names to actual Python functions.

Returns:
- str: The final text response from the LLM or the result of a tool function call.

Examples:
---------
# Example 1: Using a string input
response = llm.call("Return the name of a random city in the world.")
print(response)

# Example 2: Using a list of messages
messages = [{"role": "user", "content": "What is the capital of France?"}]
response = llm.call(messages)
print(response)
"""
# Validate parameters before proceeding with the call.
self._validate_call_params()

if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]

# For O1 models, system messages are not supported.
# Convert any system messages into assistant messages.
if "o1" in self.model.lower():
for message in messages:
if message.get("role") == "system":
message["role"] = "assistant"

with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)

try:
# --- 1) Make the completion call
# --- 1) Prepare the parameters for the completion call
params = {
"model": self.model,
"messages": messages,
@@ -207,23 +250,28 @@ class LLM:
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
"api_base": self.base_url,
"api_base": self.api_base,
"base_url": self.base_url,
"api_version": self.api_version,
"api_key": self.api_key,
"stream": False,
"tools": tools,  # pass the tool schema
"tools": tools,
"reasoning_effort": self.reasoning_effort,
**self.additional_params,
}

# Remove None values from params
params = {k: v for k, v in params.items() if v is not None}

# --- 2) Make the completion call
response = litellm.completion(**params)
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
text_response = response_message.content or ""
tool_calls = getattr(response_message, "tool_calls", [])

# Ensure callbacks get the full response object with usage info

# --- 3) Handle callbacks with usage info
if callbacks and len(callbacks) > 0:
for callback in callbacks:
if hasattr(callback, "log_success_event"):
@@ -236,11 +284,11 @@ class LLM:
end_time=0,
)

# --- 2) If no tool calls, return the text response
# --- 4) If no tool calls, return the text response
if not tool_calls or not available_functions:
return text_response

# --- 3) Handle the tool call
# --- 5) Handle the tool call
tool_call = tool_calls[0]
function_name = tool_call.function.name

@@ -255,7 +303,6 @@ class LLM:
try:
# Call the actual tool function
result = fn(**function_args)

return result

except Exception as e:
@@ -277,6 +324,36 @@ class LLM:
logging.error(f"LiteLLM call failed: {str(e)}")
raise

def _get_custom_llm_provider(self) -> str:
"""
Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
- If the model is "gemini/gemini-1.5-pro", returns "gemini".
- If there is no '/', defaults to "openai".
"""
if "/" in self.model:
return self.model.split("/")[0]
return "openai"

def _validate_call_params(self) -> None:
"""
Validate parameters before making a call. Currently this only checks if
a response_format is provided and whether the model supports it.
The custom_llm_provider is dynamically determined from the model:
- E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter"
- "gemini/gemini-1.5-pro" yields "gemini"
- If no slash is present, "openai" is assumed.
"""
provider = self._get_custom_llm_provider()
if self.response_format is not None and not supports_response_schema(
model=self.model,
custom_llm_provider=provider,
):
raise ValueError(
f"The model {self.model} does not support response_format for provider '{provider}'. "
"Please remove response_format or use a supported model."
)

def supports_function_calling(self) -> bool:
try:
params = get_supported_openai_params(model=self.model)
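A hedged sketch of the widened call API: plain-string input plus a Pydantic `response_format`, which `_validate_call_params` now checks against the provider before calling litellm. The model choice is illustrative:

```python
from pydantic import BaseModel
from crewai import LLM

class City(BaseModel):
    name: str
    country: str

llm = LLM(model="gpt-4o-mini", response_format=City)

print(llm.call("Name one city and its country."))              # string input
print(llm.call([{"role": "user", "content": "Another one?"}]))  # message list
# An unsupported model/provider combination raises ValueError up front.
```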
@@ -1,3 +1,7 @@
from typing import Optional

from pydantic import PrivateAttr

from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
@@ -10,13 +14,15 @@ class EntityMemory(Memory):
Inherits from the Memory class.
"""

def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if hasattr(crew, "memory_config") and crew.memory_config is not None:
self.memory_provider = crew.memory_config.get("provider")
else:
self.memory_provider = None
_memory_provider: Optional[str] = PrivateAttr()

if self.memory_provider == "mem0":
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider")
else:
memory_provider = None

if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
@@ -36,11 +42,13 @@ class EntityMemory(Memory):
path=path,
)
)
super().__init__(storage)

super().__init__(storage=storage)
self._memory_provider = memory_provider

def save(self, item: EntityMemoryItem) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
"""Saves an entity item into the SQLite storage."""
if self.memory_provider == "mem0":
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}

@@ -17,7 +17,7 @@ class LongTermMemory(Memory):
def __init__(self, storage=None, path=None):
if not storage:
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage)
super().__init__(storage=storage)

def save(self, item: LongTermMemoryItem) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
metadata = item.metadata
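A hedged sketch of what the refactor buys: the provider is resolved once in `__init__` and stashed in a private attribute, so constructing memory without a crew no longer trips on attribute assignment:

```python
from crewai.memory.entity.entity_memory import EntityMemory

memory = EntityMemory(crew=None)  # provider resolves to None; RAGStorage is used

# With a crew whose memory_config = {"provider": "mem0"}, Mem0Storage is chosen
# instead (and requires the optional mem0ai package to be installed).
```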
@@ -1,15 +1,19 @@
from typing import Any, Dict, List, Optional

from crewai.memory.storage.rag_storage import RAGStorage
from pydantic import BaseModel

class Memory:
class Memory(BaseModel):
"""
Base class for memory, now supporting agent tags and generic metadata.
"""

def __init__(self, storage: RAGStorage):
self.storage = storage
embedder_config: Optional[Dict[str, Any]] = None

storage: Any

def __init__(self, storage: Any, **data: Any):
super().__init__(storage=storage, **data)

def save(
self,
@@ -1,5 +1,7 @@
from typing import Any, Dict, Optional

from pydantic import PrivateAttr

from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
@@ -14,13 +16,15 @@ class ShortTermMemory(Memory):
MemoryItem instances.
"""

def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if hasattr(crew, "memory_config") and crew.memory_config is not None:
self.memory_provider = crew.memory_config.get("provider")
else:
self.memory_provider = None
_memory_provider: Optional[str] = PrivateAttr()

if self.memory_provider == "mem0":
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider")
else:
memory_provider = None

if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
@@ -39,7 +43,8 @@ class ShortTermMemory(Memory):
path=path,
)
)
super().__init__(storage)
super().__init__(storage=storage)
self._memory_provider = memory_provider

def save(
self,
@@ -48,7 +53,7 @@ class ShortTermMemory(Memory):
agent: Optional[str] = None,
) -> None:
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
if self.memory_provider == "mem0":
if self._memory_provider == "mem0":
item.data = f"Remember the following insights from Agent run: {item.data}"

super().save(value=item.data, metadata=item.metadata, agent=item.agent)
@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: Optional[Any] = None,
embedder_config: Optional[Dict[str, Any]] = None,
crew: Any = None,
):
self.type = type

@@ -23,7 +23,7 @@ class KickoffTaskOutputsSQLiteStorage:
) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()).parent / "latest_kickoff_task_outputs.db")
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
self.db_path = db_path
self._printer: Printer = Printer()
self._initialize_db()

@@ -17,7 +17,7 @@ class LTMSQLiteStorage:
) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()).parent / "long_term_memory_storage.db")
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
self.db_path = db_path
self._printer: Printer = Printer()
# Ensure parent directory exists
@@ -423,6 +423,10 @@ class Task(BaseModel):
if self.callback:
self.callback(self.output)

crew = self.agent.crew  # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
crew.task_callback(self.output)

if self._execution_span:
self._telemetry.task_ended(self._execution_span, self, agent.crew)
self._execution_span = None
@@ -431,7 +435,9 @@ class Task(BaseModel):
content = (
json_output
if json_output
else pydantic_output.model_dump_json() if pydantic_output else result
else pydantic_output.model_dump_json()
if pydantic_output
else result
)
self._save_file(content)

@@ -452,7 +458,7 @@ class Task(BaseModel):
return "\n".join(tasks_slices)

def interpolate_inputs_and_add_conversation_history(
self, inputs: Dict[str, Union[str, int, float]]
self, inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]]
) -> None:
"""Interpolate inputs into the task description, expected output, and output file path.
Add conversation history if present.
@@ -524,7 +530,9 @@ class Task(BaseModel):
)

def interpolate_only(
self, input_string: Optional[str], inputs: Dict[str, Union[str, int, float]]
self,
input_string: Optional[str],
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]],
) -> str:
"""Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched.

@@ -532,17 +540,39 @@ class Task(BaseModel):
input_string: The string containing template variables to interpolate.
Can be None or empty, in which case an empty string is returned.
inputs: Dictionary mapping template variables to their values.
Supported value types are strings, integers, and floats.
If input_string is empty or has no placeholders, inputs can be empty.
Supported value types are strings, integers, floats, and dicts/lists
containing only these types and other nested dicts/lists.

Returns:
The interpolated string with all template variables replaced with their values.
Empty string if input_string is None or empty.

Raises:
ValueError: If a required template variable is missing from inputs.
KeyError: If a template variable is not found in the inputs dictionary.
ValueError: If a value contains unsupported types
"""

# Validation function for recursive type checking
def validate_type(value: Any) -> None:
if value is None:
return
if isinstance(value, (str, int, float, bool)):
return
if isinstance(value, (dict, list)):
for item in value.values() if isinstance(value, dict) else value:
validate_type(item)
return
raise ValueError(
f"Unsupported type {type(value).__name__} in inputs. "
"Only str, int, float, bool, dict, and list are allowed."
)

# Validate all input values
for key, value in inputs.items():
try:
validate_type(value)
except ValueError as e:
raise ValueError(f"Invalid value for key '{key}': {str(e)}") from e

if input_string is None or not input_string:
return ""
if "{" not in input_string and "}" not in input_string:
@@ -551,15 +581,7 @@ class Task(BaseModel):
raise ValueError(
"Inputs dictionary cannot be empty when interpolating variables"
)

try:
# Validate input types
for key, value in inputs.items():
if not isinstance(value, (str, int, float)):
raise ValueError(
f"Value for key '{key}' must be a string, integer, or float, got {type(value).__name__}"
)

escaped_string = input_string.replace("{", "{{").replace("}", "}}")

for key in inputs.keys():
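A hedged sketch of the broadened interpolation: dict and list values now pass the recursive validation and are formatted into the template. The task fields below are minimal placeholders:

```python
from crewai import Task

task = Task(
    description="Process {items} for customer {customer}",
    expected_output="A report",
)
task.interpolate_inputs_and_add_conversation_history(
    {"items": ["a", "b"], "customer": {"name": "Acme", "tier": 1}}
)
print(task.description)  # e.g. "Process ['a', 'b'] for customer {'name': 'Acme', 'tier': 1}"

# Unsupported types (e.g. a set) raise:
# ValueError: Invalid value for key 'items': Unsupported type set in inputs. ...
```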
@@ -7,11 +7,11 @@ from crewai.utilities import I18N

i18n = I18N()

class AddImageToolSchema(BaseModel):
image_url: str = Field(..., description="The URL or path of the image to add")
action: Optional[str] = Field(
default=None,
description="Optional context or question about the image"
default=None, description="Optional context or question about the image"
)

@@ -36,10 +36,7 @@ class AddImageTool(BaseTool):
"image_url": {
"url": image_url,
},
}
},
]

return {
"role": "user",
"content": content
}
return {"role": "user", "content": content}
@@ -1,12 +1,13 @@
import ast
import datetime
import json
import re
import time
from difflib import SequenceMatcher
from json import JSONDecodeError
from textwrap import dedent
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import json5
from json_repair import repair_json

import crewai.utilities.events as events
@@ -407,28 +408,55 @@ class ToolUsage:
)
return self._tool_calling(tool_string)

def _validate_tool_input(self, tool_input: str) -> Dict[str, Any]:
def _validate_tool_input(self, tool_input: Optional[str]) -> Dict[str, Any]:
if tool_input is None:
return {}

if not isinstance(tool_input, str) or not tool_input.strip():
raise Exception(
"Tool input must be a valid dictionary in JSON or Python literal format"
)

# Attempt 1: Parse as JSON
try:
# Replace Python literals with JSON equivalents
replacements = {
r"'": '"',
r"None": "null",
r"True": "true",
r"False": "false",
}
for pattern, replacement in replacements.items():
tool_input = re.sub(pattern, replacement, tool_input)

arguments = json.loads(tool_input)
except json.JSONDecodeError:
# Attempt to repair JSON string
repaired_input = repair_json(tool_input)
try:
arguments = json.loads(repaired_input)
except json.JSONDecodeError as e:
raise Exception(f"Invalid tool input JSON: {e}")
if isinstance(arguments, dict):
return arguments
except (JSONDecodeError, TypeError):
pass  # Continue to the next parsing attempt

return arguments
# Attempt 2: Parse as Python literal
try:
arguments = ast.literal_eval(tool_input)
if isinstance(arguments, dict):
return arguments
except (ValueError, SyntaxError):
pass  # Continue to the next parsing attempt

# Attempt 3: Parse as JSON5
try:
arguments = json5.loads(tool_input)
if isinstance(arguments, dict):
return arguments
except (JSONDecodeError, ValueError, TypeError):
pass  # Continue to the next parsing attempt

# Attempt 4: Repair JSON
try:
repaired_input = repair_json(tool_input)
self._printer.print(
content=f"Repaired JSON: {repaired_input}", color="blue"
)
arguments = json.loads(repaired_input)
if isinstance(arguments, dict):
return arguments
except Exception as e:
self._printer.print(content=f"Failed to repair JSON: {e}", color="red")

# If all parsing attempts fail, raise an error
raise Exception(
"Tool input must be a valid dictionary in JSON or Python literal format"
)

def on_tool_error(self, tool: Any, tool_calling: ToolCalling, e: Exception) -> None:
event_data = self._prepare_event_data(tool, tool_calling)
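The parsing ladder is easy to mirror in isolation; a hedged, standalone sketch of the same JSON -> Python literal -> JSON5 -> json_repair order (not the class method itself):

```python
import ast
import json

import json5
from json_repair import repair_json

def parse_tool_input(tool_input: str) -> dict:
    # Try strict JSON, then a Python literal, then the lenient JSON5 dialect.
    for parser in (json.loads, ast.literal_eval, json5.loads):
        try:
            value = parser(tool_input)
            if isinstance(value, dict):
                return value
        except Exception:
            continue
    # Last resort: let json_repair fix the string, then parse it as JSON.
    value = json.loads(repair_json(tool_input))
    if isinstance(value, dict):
        return value
    raise ValueError("Tool input must parse to a dictionary")

print(parse_tool_input("{'query': 'crewai', 'limit': None}"))  # Python literal path
```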
@@ -15,7 +15,7 @@
"final_answer_format": "If you don't need to use any more tools, you must give your best complete final answer, make sure it satisfies the expected criteria, use the EXACT format below:\n\n```\nThought: I now can give a great answer\nFinal Answer: my best complete final answer to the task.\n\n```",
"format_without_tools": "\nSorry, I didn't use the right format. I MUST either use a tool (among the available ones), OR give my best final answer.\nHere is the expected format I must follow:\n\n```\nQuestion: the input question you must answer\nThought: you should always think about what to do\nAction: the action to take, should be one of [{tool_names}]\nAction Input: the input to the action\nObservation: the result of the action\n```\n This Thought/Action/Action Input/Result process can repeat N times. Once I know the final answer, I must return the following format:\n\n```\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described\n\n```",
"task_with_context": "{task}\n\nThis is the context you're working with:\n{context}",
"expected_output": "\nThis is the expect criteria for your final answer: {expected_output}\nyou MUST return the actual complete content as the final answer, not a summary.",
"expected_output": "\nThis is the expected criteria for your final answer: {expected_output}\nyou MUST return the actual complete content as the final answer, not a summary.",
"human_feedback": "You got human feedback on your work, re-evaluate it and give a new Final Answer when ready.\n {human_feedback}",
"getting_input": "This is the agent's final answer: {final_answer}\n\n",
"summarizer_system_message": "You are a helpful assistant that summarizes text.",
@@ -24,7 +24,8 @@
"manager_request": "Your best answer to your coworker asking you this, accounting for the context shared.",
"formatted_task_instructions": "Ensure your final answer contains only the content in the following format: {output_format}\n\nEnsure the final output does not include any code block markers like ```json or ```python.",
"human_feedback_classification": "Determine if the following feedback indicates that the user is satisfied or if further changes are needed. Respond with 'True' if further changes are needed, or 'False' if the user is satisfied. **Important** Do not include any additional commentary outside of your 'True' or 'False' response.\n\nFeedback: \"{feedback}\"",
"conversation_history_instruction": "You are a member of a crew collaborating to achieve a common goal. Your task is a specific action that contributes to this larger objective. For additional context, please review the conversation history between you and the user that led to the initiation of this crew. Use any relevant information or feedback from the conversation to inform your task execution and ensure your response aligns with both the immediate task and the crew's overall goals."
"conversation_history_instruction": "You are a member of a crew collaborating to achieve a common goal. Your task is a specific action that contributes to this larger objective. For additional context, please review the conversation history between you and the user that led to the initiation of this crew. Use any relevant information or feedback from the conversation to inform your task execution and ensure your response aligns with both the immediate task and the crew's overall goals.",
"feedback_instructions": "User feedback: {feedback}\nInstructions: Use this feedback to enhance the next output iteration.\nNote: Do not respond or add commentary."
},
"errors": {
"force_final_answer_error": "You can't keep going, here is the best final answer you generated:\n\n {formatted_answer}",
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, cast
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
@@ -18,11 +18,12 @@ class EmbeddingConfigurator:
|
||||
"bedrock": self._configure_bedrock,
|
||||
"huggingface": self._configure_huggingface,
|
||||
"watson": self._configure_watson,
|
||||
"custom": self._configure_custom,
|
||||
}
|
||||
|
||||
def configure_embedder(
|
||||
self,
|
||||
embedder_config: Dict[str, Any] | None = None,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Configures and returns an embedding function based on the provided config."""
|
||||
if embedder_config is None:
|
||||
@@ -30,21 +31,19 @@ class EmbeddingConfigurator:
|
||||
|
||||
provider = embedder_config.get("provider")
|
||||
config = embedder_config.get("config", {})
|
||||
model_name = config.get("model")
|
||||
|
||||
if isinstance(provider, EmbeddingFunction):
|
||||
try:
|
||||
validate_embedding_function(provider)
|
||||
return provider
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||
model_name = config.get("model") if provider != "custom" else None
|
||||
|
||||
if provider not in self.embedding_functions:
|
||||
raise Exception(
|
||||
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
||||
)
|
||||
|
||||
return self.embedding_functions[provider](config, model_name)
|
||||
embedding_function = self.embedding_functions[provider]
|
||||
return (
|
||||
embedding_function(config)
|
||||
if provider == "custom"
|
||||
else embedding_function(config, model_name)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_default_embedding_function():
|
||||
@@ -65,6 +64,13 @@ class EmbeddingConfigurator:
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
||||
model_name=model_name,
|
||||
api_base=config.get("api_base", None),
|
||||
api_type=config.get("api_type", None),
|
||||
api_version=config.get("api_version", None),
|
||||
default_headers=config.get("default_headers", None),
|
||||
dimensions=config.get("dimensions", None),
|
||||
deployment_id=config.get("deployment_id", None),
|
||||
organization_id=config.get("organization_id", None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -79,6 +85,10 @@ class EmbeddingConfigurator:
|
||||
api_type=config.get("api_type", "azure"),
|
||||
api_version=config.get("api_version"),
|
||||
model_name=model_name,
|
||||
default_headers=config.get("default_headers"),
|
||||
dimensions=config.get("dimensions"),
|
||||
deployment_id=config.get("deployment_id"),
|
||||
organization_id=config.get("organization_id"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -101,6 +111,8 @@ class EmbeddingConfigurator:
|
||||
return GoogleVertexEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
project_id=config.get("project_id"),
|
||||
region=config.get("region"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -112,6 +124,7 @@ class EmbeddingConfigurator:
|
||||
return GoogleGenerativeAiEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
task_type=config.get("task_type"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -142,9 +155,11 @@ class EmbeddingConfigurator:
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
|
||||
return AmazonBedrockEmbeddingFunction(
|
||||
session=config.get("session"),
|
||||
)
|
||||
# Allow custom model_name override with backwards compatibility
|
||||
kwargs = {"session": config.get("session")}
|
||||
if model_name is not None:
|
||||
kwargs["model_name"] = model_name
|
||||
return AmazonBedrockEmbeddingFunction(**kwargs)
|
||||

    @staticmethod
    def _configure_huggingface(config, model_name):
@@ -194,3 +209,28 @@ class EmbeddingConfigurator:
            raise e

        return WatsonEmbeddingFunction()

    @staticmethod
    def _configure_custom(config):
        custom_embedder = config.get("embedder")
        if isinstance(custom_embedder, EmbeddingFunction):
            try:
                validate_embedding_function(custom_embedder)
                return custom_embedder
            except Exception as e:
                raise ValueError(f"Invalid custom embedding function: {str(e)}")
        elif callable(custom_embedder):
            try:
                instance = custom_embedder()
                if isinstance(instance, EmbeddingFunction):
                    validate_embedding_function(instance)
                    return instance
                raise ValueError(
                    "Custom embedder does not create an EmbeddingFunction instance"
                )
            except Exception as e:
                raise ValueError(f"Error instantiating custom embedder: {str(e)}")
        else:
            raise ValueError(
                "Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
            )
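
_configure_custom accepts either a ready EmbeddingFunction or a zero-argument factory that produces one. A minimal sketch against Chroma's interface (the class here is a toy stand-in, not a real model):

    from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

    class MyEmbedder(EmbeddingFunction):
        """Toy embedder: fixed-size zero vectors in place of a real model."""

        def __call__(self, input: Documents) -> Embeddings:
            return [[0.0] * 384 for _ in input]

    embedder_config = {
        "provider": "custom",
        # An instance works; so does the class itself, since it is a
        # zero-argument callable that creates an EmbeddingFunction.
        "config": {"embedder": MyEmbedder()},
    }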

@@ -92,13 +92,34 @@ class TaskEvaluator:
        """

        output_training_data = training_data[agent_id]

        final_aggregated_data = ""
        for _, data in output_training_data.items():

        for iteration, data in output_training_data.items():
            improved_output = data.get("improved_output")
            initial_output = data.get("initial_output")
            human_feedback = data.get("human_feedback")

            if not all([improved_output, initial_output, human_feedback]):
                missing_fields = [
                    field
                    for field in ["improved_output", "initial_output", "human_feedback"]
                    if not data.get(field)
                ]
                error_msg = (
                    f"Critical training data error: Missing fields ({', '.join(missing_fields)}) "
                    f"for agent {agent_id} in iteration {iteration}.\n"
                    "This indicates a broken training process. "
                    "Cannot proceed with evaluation.\n"
                    "Please check your training implementation."
                )
                raise ValueError(error_msg)

            final_aggregated_data += (
                f"Initial Output:\n{data['initial_output']}\n\n"
                f"Human Feedback:\n{data['human_feedback']}\n\n"
                f"Improved Output:\n{data['improved_output']}\n\n"
                f"Iteration: {iteration}\n"
                f"Initial Output:\n{initial_output}\n\n"
                f"Human Feedback:\n{human_feedback}\n\n"
                f"Improved Output:\n{improved_output}\n\n"
                "------------------------------------------------\n\n"
            )

        evaluation_query = (
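
The stricter loop assumes one record per training iteration with three populated fields; a missing or empty field now raises ValueError instead of silently producing an empty section. The expected shape, sketched (ids and text are placeholders):

    training_data = {
        "agent-uuid": {                          # agent_id (placeholder)
            0: {
                "initial_output": "first draft",
                "human_feedback": "tighten the summary",
                "improved_output": "revised draft",
            },
            # ...one entry per training iteration
        },
    }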

@@ -1,30 +1,64 @@
import json
import os
import pickle
from datetime import datetime
from typing import Union


class FileHandler:
    """take care of file operations, currently it only logs messages to a file"""
    """Handler for file operations supporting both JSON and text-based logging.

    Args:
        file_path (Union[bool, str]): Path to the log file or a boolean flag.
    """

    def __init__(self, file_path):
        if isinstance(file_path, bool):
    def __init__(self, file_path: Union[bool, str]):
        self._initialize_path(file_path)

    def _initialize_path(self, file_path: Union[bool, str]):
        if file_path is True:  # File path is boolean True
            self._path = os.path.join(os.curdir, "logs.txt")
        elif isinstance(file_path, str):
            self._path = file_path

        elif isinstance(file_path, str):  # File path is a string
            if file_path.endswith((".json", ".txt")):
                self._path = file_path  # No modification if the file ends with .json or .txt
            else:
                self._path = file_path + ".txt"  # Append .txt if the file doesn't end with .json or .txt

        else:
            raise ValueError("file_path must be either a boolean or a string.")

            raise ValueError("file_path must be a string or boolean.")  # Handle the case where file_path isn't valid

    def log(self, **kwargs):
        now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        message = (
            f"{now}: "
            + ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
            + "\n"
        )
        with open(self._path, "a", encoding="utf-8") as file:
            file.write(message + "\n")
        try:
            now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            log_entry = {"timestamp": now, **kwargs}

            if self._path.endswith(".json"):
                # Append log in JSON format
                with open(self._path, "a", encoding="utf-8") as file:
                    # If the file is empty, start with a list; else, append to it
                    try:
                        # Try reading existing content to avoid overwriting
                        with open(self._path, "r", encoding="utf-8") as read_file:
                            existing_data = json.load(read_file)
                            existing_data.append(log_entry)
                    except (json.JSONDecodeError, FileNotFoundError):
                        # If no valid JSON or file doesn't exist, start with an empty list
                        existing_data = [log_entry]

                    with open(self._path, "w", encoding="utf-8") as write_file:
                        json.dump(existing_data, write_file, indent=4)
                        write_file.write("\n")

            else:
                # Append log in plain text format
                message = f"{now}: " + ", ".join([f"{key}=\"{value}\"" for key, value in kwargs.items()]) + "\n"
                with open(self._path, "a", encoding="utf-8") as file:
                    file.write(message)

        except Exception as e:
            raise ValueError(f"Failed to log message: {str(e)}")


class PickleHandler:
    def __init__(self, file_name: str) -> None:
        """

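A usage sketch of the two logging modes introduced here (file names are placeholders; the import path matches the one used later in this diff):

    from crewai.utilities.file_handler import FileHandler

    json_logger = FileHandler("crew_log.json")   # accumulates a pretty-printed JSON list
    json_logger.log(task="research", status="started")

    text_logger = FileHandler(True)              # boolean True defaults to ./logs.txt
    text_logger.log(task="research", status="completed")
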
@@ -24,12 +24,10 @@ def create_llm(

    # 1) If llm_value is already an LLM object, return it directly
    if isinstance(llm_value, LLM):
        print("LLM value is already an LLM object")
        return llm_value

    # 2) If llm_value is a string (model name)
    if isinstance(llm_value, str):
        print("LLM value is a string")
        try:
            created_llm = LLM(model=llm_value)
            return created_llm
@@ -39,12 +37,10 @@ def create_llm(

    # 3) If llm_value is None, parse environment variables or use default
    if llm_value is None:
        print("LLM value is None")
        return _llm_via_environment_or_fallback()

    # 4) Otherwise, attempt to extract relevant attributes from an unknown object
    try:
        print("LLM value is an unknown object")
        # Extract attributes with explicit types
        model = (
            getattr(llm_value, "model_name", None)
@@ -57,6 +53,7 @@ def create_llm(
        timeout: Optional[float] = getattr(llm_value, "timeout", None)
        api_key: Optional[str] = getattr(llm_value, "api_key", None)
        base_url: Optional[str] = getattr(llm_value, "base_url", None)
        api_base: Optional[str] = getattr(llm_value, "api_base", None)

        created_llm = LLM(
            model=model,
@@ -66,6 +63,7 @@ def create_llm(
            timeout=timeout,
            api_key=api_key,
            base_url=base_url,
            api_base=api_base,
        )
        return created_llm
    except Exception as e:
@@ -105,8 +103,18 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
    callbacks: List[Any] = []

    # Optional base URL from env
    api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get("OPENAI_BASE_URL")
    if api_base:
    base_url = (
        os.environ.get("BASE_URL")
        or os.environ.get("OPENAI_API_BASE")
        or os.environ.get("OPENAI_BASE_URL")
    )

    api_base = os.environ.get("API_BASE") or os.environ.get("AZURE_API_BASE")

    # Synchronize base_url and api_base if one is populated and the other is not
    if base_url and not api_base:
        api_base = base_url
    elif api_base and not base_url:
        base_url = api_base

    # Initialize llm_params dictionary
@@ -119,6 +127,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
        "timeout": timeout,
        "api_key": api_key,
        "base_url": base_url,
        "api_base": api_base,
        "api_version": api_version,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
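
A standalone sketch mirroring the new environment resolution in this hunk (not the library API; the endpoint value is a placeholder):

    import os
    from typing import Optional, Tuple

    def resolve_base_urls() -> Tuple[Optional[str], Optional[str]]:
        """Mirror of the env resolution and sync step above."""
        base_url = (
            os.environ.get("BASE_URL")
            or os.environ.get("OPENAI_API_BASE")
            or os.environ.get("OPENAI_BASE_URL")
        )
        api_base = os.environ.get("API_BASE") or os.environ.get("AZURE_API_BASE")
        # Whichever side is populated fills the other, so both reach llm_params.
        if base_url and not api_base:
            api_base = base_url
        elif api_base and not base_url:
            base_url = api_base
        return base_url, api_base

    os.environ["AZURE_API_BASE"] = "https://example.openai.azure.com"  # placeholder
    print(resolve_base_urls())  # both entries now hold the Azure endpoint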

@@ -7,7 +7,7 @@ import appdirs

def db_storage_path() -> str:
    """Returns the path for SQLite database storage.

    Returns:
        str: Full path to the SQLite database file
    """
@@ -16,7 +16,7 @@ def db_storage_path() -> str:

    data_dir = Path(appdirs.user_data_dir(app_name, app_author))
    data_dir.mkdir(parents=True, exist_ok=True)
    return str(data_dir / "crewai_flows.db")
    return str(data_dir)


def get_project_directory_name():
@@ -28,4 +28,4 @@ def get_project_directory_name():
    else:
        cwd = Path.cwd()
        project_directory_name = cwd.name
        return project_directory_name
    return project_directory_name
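
db_storage_path now returns the storage directory rather than one specific database file, so callers pick their own filenames. A sketch (the module path is assumed; the filename shown is the one removed above):

    from pathlib import Path
    from crewai.utilities.paths import db_storage_path

    storage_dir = Path(db_storage_path())
    flows_db = storage_dir / "crewai_flows.db"  # callers now append a filename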
@@ -21,6 +21,16 @@ class Printer:
            self._print_yellow(content)
        elif color == "bold_yellow":
            self._print_bold_yellow(content)
        elif color == "cyan":
            self._print_cyan(content)
        elif color == "bold_cyan":
            self._print_bold_cyan(content)
        elif color == "magenta":
            self._print_magenta(content)
        elif color == "bold_magenta":
            self._print_bold_magenta(content)
        elif color == "green":
            self._print_green(content)
        else:
            print(content)

@@ -44,3 +54,18 @@ class Printer:

    def _print_bold_yellow(self, content):
        print("\033[1m\033[93m {}\033[00m".format(content))

    def _print_cyan(self, content):
        print("\033[96m {}\033[00m".format(content))

    def _print_bold_cyan(self, content):
        print("\033[1m\033[96m {}\033[00m".format(content))

    def _print_magenta(self, content):
        print("\033[35m {}\033[00m".format(content))

    def _print_bold_magenta(self, content):
        print("\033[1m\033[35m {}\033[00m".format(content))

    def _print_green(self, content):
        print("\033[32m {}\033[00m".format(content))
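
The dispatcher gains cyan, magenta, and green variants. A usage sketch (the public print signature is assumed from the dispatch shown above):

    from crewai.utilities.printer import Printer

    printer = Printer()
    printer.print("Starting crew...", color="bold_cyan")
    printer.print("Crew finished", color="green")
    printer.print("no styling", color="plain")  # unmatched colors fall through to plain print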
@@ -1,3 +1,5 @@
import os

from crewai.utilities.file_handler import PickleHandler


@@ -29,3 +31,8 @@ class CrewTrainingHandler(PickleHandler):
        data[agent_id] = {train_iteration: new_data}

        self.save(data)

    def clear(self) -> None:
        """Clear the training data by removing the file or resetting its contents."""
        if os.path.exists(self.file_path):
            self.save({})
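
clear() resets the pickle in place rather than deleting the file. A usage sketch (TRAINING_DATA_FILE comes from crewai.utilities.constants; the handler's module path is assumed):

    from crewai.utilities.constants import TRAINING_DATA_FILE
    from crewai.utilities.training_handler import CrewTrainingHandler

    handler = CrewTrainingHandler(TRAINING_DATA_FILE)
    handler.clear()  # file remains on disk, contents reset to an empty dict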