Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-03 13:18:29 +00:00)

Compare commits: 8 commits
devin/1746... → lg-pytest-
| Author | SHA1 | Date |
|---|---|---|
| | c96301fe0d | |
| | ea52e6fc8f | |
| | bf11f30f7a | |
| | b453fd3995 | |
| | e1961dccae | |
| | 7b903414cd | |
| | 133e5892b9 | |
| | e3ab999a57 | |
@@ -11,7 +11,7 @@ dependencies = [
     # Core Dependencies
     "pydantic>=2.4.2",
     "openai>=1.13.3",
-    "litellm==1.68.0",
+    "litellm==1.67.1",
     "instructor>=1.3.3",
     # Text Processing
     "pdfplumber>=0.11.4",
@@ -201,22 +201,9 @@ def install(context):


 @crewai.command()
-@click.option(
-    "--record",
-    is_flag=True,
-    help="Record LLM responses for later replay",
-)
-@click.option(
-    "--replay",
-    is_flag=True,
-    help="Replay from recorded LLM responses without making network calls",
-)
-def run(record: bool = False, replay: bool = False):
+def run():
     """Run the Crew."""
-    if record and replay:
-        raise click.UsageError("Cannot use --record and --replay simultaneously")
     click.echo("Running the Crew")
-    run_crew(record=record, replay=replay)
+    run_crew()


 @crewai.command()
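Note: the removed side of this hunk wires mutually exclusive --record/--replay flags into the Click command. A minimal standalone sketch of that flag-validation pattern, with illustrative names only (it is not the project's actual module):

    import click

    @click.command()
    @click.option("--record", is_flag=True, help="Record LLM responses for later replay")
    @click.option("--replay", is_flag=True, help="Replay recorded LLM responses offline")
    def run(record: bool = False, replay: bool = False):
        # Click passes both flags; reject the contradictory combination up front.
        if record and replay:
            raise click.UsageError("Cannot use --record and --replay simultaneously")
        click.echo(f"record={record}, replay={replay}")

    if __name__ == "__main__":
        run()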
@@ -2,7 +2,7 @@ import subprocess

 import click

-from crewai.cli.utils import get_crews
+from crewai.cli.utils import get_crew


 def reset_memories_command(
@@ -26,47 +26,35 @@ def reset_memories_command(
     """

     try:
-        if not any([long, short, entity, kickoff_outputs, knowledge, all]):
+        crew = get_crew()
+        if not crew:
+            raise ValueError("No crew found.")
+        if all:
+            crew.reset_memories(command_type="all")
+            click.echo("All memories have been reset.")
+            return
+
+        if not any([long, short, entity, kickoff_outputs, knowledge]):
             click.echo(
                 "No memory type specified. Please specify at least one type to reset."
             )
             return

-        crews = get_crews()
-        if not crews:
-            raise ValueError("No crew found.")
-        for crew in crews:
-            if all:
-                crew.reset_memories(command_type="all")
-                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed."
-                )
-                continue
-            if long:
-                crew.reset_memories(command_type="long")
-                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset."
-                )
-            if short:
-                crew.reset_memories(command_type="short")
-                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset."
-                )
-            if entity:
-                crew.reset_memories(command_type="entity")
-                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset."
-                )
-            if kickoff_outputs:
-                crew.reset_memories(command_type="kickoff_outputs")
-                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset."
-                )
-            if knowledge:
-                crew.reset_memories(command_type="knowledge")
-                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset."
-                )
+        if long:
+            crew.reset_memories(command_type="long")
+            click.echo("Long term memory has been reset.")
+        if short:
+            crew.reset_memories(command_type="short")
+            click.echo("Short term memory has been reset.")
+        if entity:
+            crew.reset_memories(command_type="entity")
+            click.echo("Entity memory has been reset.")
+        if kickoff_outputs:
+            crew.reset_memories(command_type="kickoff_outputs")
+            click.echo("Latest Kickoff outputs stored has been reset.")
+        if knowledge:
+            crew.reset_memories(command_type="knowledge")
+            click.echo("Knowledge has been reset.")

     except subprocess.CalledProcessError as e:
        click.echo(f"An error occurred while resetting the memories: {e}", err=True)
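A quick way to exercise this command without a real project is Click's test runner, as the test changes later in this compare also do. The flag names come from the hunk above; the import path is an assumption about where the command is exposed:

    from click.testing import CliRunner
    # Assumed import path; the command lives somewhere under crewai.cli in the repo.
    from crewai.cli.cli import reset_memories

    runner = CliRunner()
    result = runner.invoke(reset_memories, ["-l", "-s"])  # reset long- and short-term memory
    print(result.output)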
@@ -14,17 +14,13 @@ class CrewType(Enum):
     FLOW = "flow"


-def run_crew(record: bool = False, replay: bool = False) -> None:
+def run_crew() -> None:
     """
     Run the crew or flow by running a command in the UV environment.

     Starting from version 0.103.0, this command can be used to run both
     standard crews and flows. For flows, it detects the type from pyproject.toml
     and automatically runs the appropriate command.
-
-    Args:
-        record (bool, optional): Whether to record LLM responses. Defaults to False.
-        replay (bool, optional): Whether to replay from recorded LLM responses. Defaults to False.
     """
     crewai_version = get_crewai_version()
     min_required_version = "0.71.0"
@@ -48,24 +44,17 @@ def run_crew(record: bool = False, replay: bool = False) -> None:
     click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")

     # Execute the appropriate command
-    execute_command(crew_type, record, replay)
+    execute_command(crew_type)


-def execute_command(crew_type: CrewType, record: bool = False, replay: bool = False) -> None:
+def execute_command(crew_type: CrewType) -> None:
     """
     Execute the appropriate command based on crew type.

     Args:
         crew_type: The type of crew to run
-        record: Whether to record LLM responses
-        replay: Whether to replay from recorded LLM responses
     """
     command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]

-    if record:
-        command.append("--record")
-    if replay:
-        command.append("--replay")
-
     try:
         subprocess.run(command, capture_output=False, text=True, check=True)
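For reference, the removed branch builds the subprocess argument list before shelling out to uv. A rough sketch of that composition; the FLOW member and flag names are taken from the hunk above, while the STANDARD member name is assumed:

    import subprocess
    from enum import Enum

    class CrewType(Enum):
        STANDARD = "standard"  # assumed name for the non-flow member
        FLOW = "flow"

    def build_command(crew_type: CrewType, record: bool = False, replay: bool = False) -> list[str]:
        # "kickoff" is the flow entry point, "run_crew" the standard one.
        command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
        if record:
            command.append("--record")
        if replay:
            command.append("--replay")
        return command

    print(build_command(CrewType.STANDARD, record=True))  # ['uv', 'run', 'run_crew', '--record']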
@@ -2,8 +2,7 @@ import os
 import shutil
 import sys
 from functools import reduce
-from inspect import isfunction, ismethod
-from typing import Any, Dict, List, get_type_hints
+from typing import Any, Dict, List

 import click
 import tomli
@@ -11,7 +10,6 @@ from rich.console import Console

 from crewai.cli.constants import ENV_VARS
 from crewai.crew import Crew
-from crewai.flow import Flow

 if sys.version_info >= (3, 11):
     import tomllib
@@ -252,11 +250,11 @@ def write_env_file(folder_path, env_vars):
         file.write(f"{key}={value}\n")


-def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
-    """Get the crew instances from the a file."""
-    crew_instances = []
+def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
+    """Get the crew instance from the crew.py file."""
     try:
         import importlib.util
         import os

         for root, _, files in os.walk("."):
             if crew_path in files:
@@ -273,10 +271,12 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
                     spec.loader.exec_module(module)

                     for attr_name in dir(module):
-                        module_attr = getattr(module, attr_name)
-
+                        attr = getattr(module, attr_name)
                         try:
-                            crew_instances.extend(fetch_crews(module_attr))
+                            if callable(attr) and hasattr(attr, "crew"):
+                                crew_instance = attr().crew()
+                                return crew_instance
+
                         except Exception as e:
                             print(f"Error processing attribute {attr_name}: {e}")
                             continue
@@ -286,6 +286,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
                     import traceback

                     print(f"Traceback: {traceback.format_exc()}")

     except (ImportError, AttributeError) as e:
         if require:
             console.print(
@@ -299,6 +300,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
         if require:
             console.print("No valid Crew instance found in crew.py", style="bold red")
             raise SystemExit
+        return None

     except Exception as e:
         if require:
@@ -306,36 +308,4 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
                 f"Unexpected error while loading crew: {str(e)}", style="bold red"
             )
             raise SystemExit
-
-    return crew_instances
-
-
-def get_crew_instance(module_attr) -> Crew | None:
-    if (
-        callable(module_attr)
-        and hasattr(module_attr, "is_crew_class")
-        and module_attr.is_crew_class
-    ):
-        return module_attr().crew()
-    if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
-        module_attr
-    ).get("return") is Crew:
-        return module_attr()
-    elif isinstance(module_attr, Crew):
-        return module_attr
-    else:
-        return None
-
-
-def fetch_crews(module_attr) -> list[Crew]:
-    crew_instances: list[Crew] = []
-
-    if crew_instance := get_crew_instance(module_attr):
-        crew_instances.append(crew_instance)
-
-    if isinstance(module_attr, type) and issubclass(module_attr, Flow):
-        instance = module_attr()
-        for attr_name in dir(instance):
-            attr = getattr(instance, attr_name)
-            if crew_instance := get_crew_instance(attr):
-                crew_instances.append(crew_instance)
-    return crew_instances
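The deleted fetch_crews/get_crew_instance helpers walk a module's attributes and collect every Crew they can build, including crews created inside Flow methods. A generic, self-contained illustration of that reflection pattern (no crewAI imports; the predicate stands in for the Crew checks in the diff):

    import inspect
    from types import ModuleType
    from typing import Any, Callable, List

    def discover(module: ModuleType, is_target: Callable[[Any], bool]) -> List[Any]:
        """Collect every attribute of `module` that satisfies `is_target`,
        mirroring how the deleted helpers scan dir(module) for Crew objects."""
        found = []
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            try:
                if is_target(attr):
                    found.append(attr)
            except Exception:
                continue  # mirror the diff's defensive skip of misbehaving attributes
        return found

    # Example: list the classes defined in a standard-library module.
    import json
    print(discover(json, inspect.isclass))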
@@ -6,17 +6,7 @@ import warnings
 from concurrent.futures import Future
 from copy import copy as shallow_copy
 from hashlib import md5
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Union,
-    cast,
-)
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast

 from pydantic import (
     UUID4,
@@ -34,7 +24,6 @@ from crewai.agent import Agent
 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.cache import CacheHandler
 from crewai.crews.crew_output import CrewOutput
-from crewai.flow.flow_trackable import FlowTrackable
 from crewai.knowledge.knowledge import Knowledge
 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.llm import LLM, BaseLLM
@@ -80,7 +69,7 @@ from crewai.utilities.training_handler import CrewTrainingHandler
 warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")


-class Crew(FlowTrackable, BaseModel):
+class Crew(BaseModel):
     """
     Represents a group of agents, defining how they should collaborate and the tasks they should perform.

@@ -244,15 +233,6 @@ class Crew(FlowTrackable, BaseModel):
         default_factory=SecurityConfig,
         description="Security configuration for the crew, including fingerprinting.",
     )
-    record_mode: bool = Field(
-        default=False,
-        description="Whether to record LLM responses for later replay.",
-    )
-    replay_mode: bool = Field(
-        default=False,
-        description="Whether to replay from recorded LLM responses without making network calls.",
-    )
-    _llm_response_cache_handler: Optional[Any] = PrivateAttr(default=None)

     @field_validator("id", mode="before")
     @classmethod
@@ -324,9 +304,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""Initialize private memory attributes."""
|
||||
self._external_memory = (
|
||||
# External memory doesn’t support a default value since it was designed to be managed entirely externally
|
||||
self.external_memory.set_crew(self)
|
||||
if self.external_memory
|
||||
else None
|
||||
self.external_memory.set_crew(self) if self.external_memory else None
|
||||
)
|
||||
|
||||
self._long_term_memory = self.long_term_memory
|
||||
@@ -355,7 +333,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
embedder=self.embedder,
|
||||
collection_name="crew",
|
||||
)
|
||||
self.knowledge.add_sources()
|
||||
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
@@ -642,19 +619,6 @@ class Crew(FlowTrackable, BaseModel):
         self._task_output_handler.reset()
         self._logging_color = "bold_purple"

-        if self.record_mode and self.replay_mode:
-            raise ValueError("Cannot use both record_mode and replay_mode at the same time")
-
-        if self.record_mode or self.replay_mode:
-            from crewai.utilities.llm_response_cache_handler import (
-                LLMResponseCacheHandler,
-            )
-            self._llm_response_cache_handler = LLMResponseCacheHandler()
-            if self.record_mode:
-                self._llm_response_cache_handler.start_recording()
-            elif self.replay_mode:
-                self._llm_response_cache_handler.start_replaying()
-
         if inputs is not None:
             self._inputs = inputs
             self._interpolate_inputs(inputs)
@@ -673,12 +637,6 @@ class Crew(FlowTrackable, BaseModel):

             if not agent.step_callback:  # type: ignore # "BaseAgent" has no attribute "step_callback"
                 agent.step_callback = self.step_callback  # type: ignore # "BaseAgent" has no attribute "step_callback"

-            if self._llm_response_cache_handler:
-                if hasattr(agent, "llm") and agent.llm:
-                    agent.llm.set_response_cache_handler(self._llm_response_cache_handler)
-                if hasattr(agent, "function_calling_llm") and agent.function_calling_llm:
-                    agent.function_calling_llm.set_response_cache_handler(self._llm_response_cache_handler)
-
             agent.create_agent_executor()

@@ -1315,9 +1273,6 @@ class Crew(FlowTrackable, BaseModel):
     def _finish_execution(self, final_string_output: str) -> None:
         if self.max_rpm:
             self._rpm_controller.stop_rpm_counter()

-        if self._llm_response_cache_handler:
-            self._llm_response_cache_handler.stop()
-
     def calculate_usage_metrics(self) -> UsageMetrics:
         """Calculates and returns the usage metrics."""
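Taken together, the removed fields and kickoff wiring imply a usage pattern along these lines. This is an illustrative sketch against the record/replay API that exists only on the removed side of this compare (it also needs valid LLM credentials for the recording run):

    from crewai import Agent, Crew, Task

    analyst = Agent(role="Analyst", goal="Summarize input", backstory="Test agent")
    summarize = Task(description="Summarize the input text",
                     expected_output="A short summary", agent=analyst)

    # First run records every LLM response into the local SQLite cache...
    Crew(agents=[analyst], tasks=[summarize], record_mode=True).kickoff()

    # ...so a later run can replay them without making network calls.
    Crew(agents=[analyst], tasks=[summarize], replay_mode=True).kickoff()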
@@ -1414,6 +1369,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
else:
|
||||
self._reset_specific_memory(command_type)
|
||||
|
||||
self._logger.log("info", f"{command_type} memory has been reset")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to reset {command_type} memory: {str(e)}"
|
||||
self._logger.log("error", error_msg)
|
||||
@@ -1434,14 +1391,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if system is not None:
|
||||
try:
|
||||
system.reset()
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
|
||||
) from e
|
||||
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||
|
||||
def _reset_specific_memory(self, memory_type: str) -> None:
|
||||
"""Reset a specific memory system.
|
||||
@@ -1470,11 +1421,5 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
try:
|
||||
memory_system.reset()
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
|
||||
) from e
|
||||
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
import inspect
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf, model_validator
|
||||
|
||||
from crewai.flow import Flow
|
||||
|
||||
|
||||
class FlowTrackable(BaseModel):
|
||||
"""Mixin that tracks the Flow instance that instantiated the object, e.g. a
|
||||
Flow instance that created a Crew or Agent.
|
||||
|
||||
Automatically finds and stores a reference to the parent Flow instance by
|
||||
inspecting the call stack.
|
||||
"""
|
||||
|
||||
parent_flow: Optional[InstanceOf[Flow]] = Field(
|
||||
default=None,
|
||||
description="The parent flow of the instance, if it was created inside a flow.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_parent_flow(self, max_depth: int = 5) -> "FlowTrackable":
|
||||
frame = inspect.currentframe()
|
||||
|
||||
try:
|
||||
if frame is None:
|
||||
return self
|
||||
|
||||
frame = frame.f_back
|
||||
for _ in range(max_depth):
|
||||
if frame is None:
|
||||
break
|
||||
|
||||
candidate = frame.f_locals.get("self")
|
||||
if isinstance(candidate, Flow):
|
||||
self.parent_flow = candidate
|
||||
break
|
||||
|
||||
frame = frame.f_back
|
||||
finally:
|
||||
del frame
|
||||
|
||||
return self
|
||||
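The deleted mixin works by walking the interpreter call stack to find the Flow that constructed the object. A minimal, generic sketch of that stack-inspection technique, with no crewAI dependency and purely illustrative class names:

    import inspect

    class Parent:
        def build(self):
            return Child()  # Child is constructed inside a Parent method

    class Child:
        def __init__(self, max_depth: int = 5):
            self.parent = None
            frame = inspect.currentframe()
            try:
                frame = frame.f_back if frame else None
                for _ in range(max_depth):
                    if frame is None:
                        break
                    candidate = frame.f_locals.get("self")
                    if isinstance(candidate, Parent):
                        self.parent = candidate  # found the instantiating Parent
                        break
                    frame = frame.f_back
            finally:
                del frame  # avoid reference cycles through the frame object

    p = Parent()
    c = p.build()
    assert c.parent is p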
@@ -41,6 +41,7 @@ class Knowledge(BaseModel):
         )
         self.sources = sources
         self.storage.initialize_knowledge_storage()
+        self._add_sources()

     def query(
         self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
@@ -62,7 +63,7 @@ class Knowledge(BaseModel):
         )
         return results

-    def add_sources(self):
+    def _add_sources(self):
         try:
             for source in self.sources:
                 source.storage = self.storage
@@ -13,7 +13,6 @@ from crewai.agents.parser import (
     AgentFinish,
     OutputParserException,
 )
-from crewai.flow.flow_trackable import FlowTrackable
 from crewai.llm import LLM
 from crewai.tools.base_tool import BaseTool
 from crewai.tools.structured_tool import CrewStructuredTool
@@ -81,7 +80,7 @@ class LiteAgentOutput(BaseModel):
         return self.raw


-class LiteAgent(FlowTrackable, BaseModel):
+class LiteAgent(BaseModel):
     """
     A lightweight agent that can process messages and use tools.

@@ -163,7 +162,7 @@ class LiteAgent(FlowTrackable, BaseModel):
     _messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
     _iterations: int = PrivateAttr(default=0)
     _printer: Printer = PrivateAttr(default_factory=Printer)


     @model_validator(mode="after")
     def setup_llm(self):
         """Set up the LLM and other components after initialization."""
@@ -296,7 +296,6 @@ class LLM(BaseLLM):
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream
-        self._response_cache_handler = None

         litellm.drop_params = True

@@ -870,43 +869,25 @@ class LLM(BaseLLM):
             for message in messages:
                 if message.get("role") == "system":
                     message["role"] = "assistant"

-        if self._response_cache_handler and self._response_cache_handler.is_replaying():
-            cached_response = self._response_cache_handler.get_cached_response(
-                self.model, messages
-            )
-            if cached_response:
-                # Emit completion event for the cached response
-                self._handle_emit_call_events(cached_response, LLMCallType.LLM_CALL)
-                return cached_response
-
-        # --- 6) Set up callbacks if provided
+        # --- 5) Set up callbacks if provided
         with suppress_warnings():
             if callbacks and len(callbacks) > 0:
                 self.set_callbacks(callbacks)

         try:
-            # --- 7) Prepare parameters for the completion call
+            # --- 6) Prepare parameters for the completion call
             params = self._prepare_completion_params(messages, tools)

-            # --- 8) Make the completion call and handle response
+            # --- 7) Make the completion call and handle response
             if self.stream:
-                response = self._handle_streaming_response(
+                return self._handle_streaming_response(
                     params, callbacks, available_functions
                 )
             else:
-                response = self._handle_non_streaming_response(
+                return self._handle_non_streaming_response(
                     params, callbacks, available_functions
                 )

-            if (self._response_cache_handler and
-                self._response_cache_handler.is_recording() and
-                isinstance(response, str)):
-                self._response_cache_handler.cache_response(
-                    self.model, messages, response
-                )
-
-            return response
-
         except LLMContextLengthExceededException:
             # Re-raise LLMContextLengthExceededException as it should be handled
@@ -1126,18 +1107,3 @@ class LLM(BaseLLM):

         litellm.success_callback = success_callbacks
         litellm.failure_callback = failure_callbacks
-
-    def set_response_cache_handler(self, handler):
-        """
-        Sets the response cache handler for record/replay functionality.
-
-        Args:
-            handler: An instance of LLMResponseCacheHandler.
-        """
-        self._response_cache_handler = handler
-
-    def clear_response_cache_handler(self):
-        """
-        Clears the response cache handler.
-        """
-        self._response_cache_handler = None
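The removed call path boils down to a cache-aside wrapper around the completion call: consult the handler when replaying, store the response when recording. A generic sketch of that pattern in plain Python, with no crewAI or litellm dependency and illustrative names only:

    from typing import Callable, Dict, Optional

    class ResponseCache:
        def __init__(self) -> None:
            self._store: Dict[str, str] = {}
            self.recording = False
            self.replaying = False

        def get(self, key: str) -> Optional[str]:
            return self._store.get(key) if self.replaying else None

        def put(self, key: str, value: str) -> None:
            if self.recording:
                self._store[key] = value

    def call_llm(prompt: str, completion: Callable[[str], str], cache: ResponseCache) -> str:
        cached = cache.get(prompt)
        if cached is not None:
            return cached              # replay path: no network call
        response = completion(prompt)  # live path
        cache.put(prompt, response)    # record path
        return response

    cache = ResponseCache()
    cache.recording = True
    call_llm("hello", lambda p: p.upper(), cache)   # records "HELLO"
    cache.recording, cache.replaying = False, True
    print(call_llm("hello", lambda p: "should not be used", cache))  # -> "HELLO" from cache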
@@ -1,314 +0,0 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMResponseCacheStorage:
|
||||
"""
|
||||
SQLite storage for caching LLM responses.
|
||||
Used for offline record/replay functionality.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, db_path: str = f"{db_storage_path()}/llm_response_cache.db"
|
||||
) -> None:
|
||||
self.db_path = db_path
|
||||
self._printer: Printer = Printer()
|
||||
self._connection_pool: Dict[int, sqlite3.Connection] = {}
|
||||
self._initialize_db()
|
||||
|
||||
def _get_connection(self) -> sqlite3.Connection:
|
||||
"""
|
||||
Gets a connection from the connection pool or creates a new one.
|
||||
Uses thread-local storage to ensure thread safety.
|
||||
"""
|
||||
thread_id = threading.get_ident()
|
||||
if thread_id not in self._connection_pool:
|
||||
try:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
conn.execute("PRAGMA journal_mode = WAL")
|
||||
self._connection_pool[thread_id] = conn
|
||||
except sqlite3.Error as e:
|
||||
error_msg = f"Failed to create SQLite connection: {e}"
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
|
||||
color="red",
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
return self._connection_pool[thread_id]
|
||||
|
||||
def _close_connections(self) -> None:
|
||||
"""
|
||||
Closes all connections in the connection pool.
|
||||
"""
|
||||
for thread_id, conn in list(self._connection_pool.items()):
|
||||
try:
|
||||
conn.close()
|
||||
del self._connection_pool[thread_id]
|
||||
except sqlite3.Error as e:
|
||||
error_msg = f"Failed to close SQLite connection: {e}"
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
|
||||
color="red",
|
||||
)
|
||||
logger.error(error_msg)
|
||||
|
||||
def _initialize_db(self) -> None:
|
||||
"""
|
||||
Initializes the SQLite database and creates the llm_response_cache table
|
||||
"""
|
||||
try:
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS llm_response_cache (
|
||||
request_hash TEXT PRIMARY KEY,
|
||||
model TEXT,
|
||||
messages TEXT,
|
||||
response TEXT,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = f"Failed to initialize database: {e}"
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
|
||||
color="red",
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
-    def _compute_request_hash(self, model: str, messages: List[Dict[str, str]]) -> str:
-        """
-        Computes a hash for the request based on the model and messages.
-        This hash is used as the key for caching.
-
-        Sensitive information like API keys should not be included in the hash.
-        """
-        try:
-            message_str = json.dumps(messages, sort_keys=True)
-            request_hash = hashlib.sha256(f"{model}:{message_str}".encode()).hexdigest()
-            return request_hash
-        except Exception as e:
-            error_msg = f"Failed to compute request hash: {e}"
-            self._printer.print(
-                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
-                color="red",
-            )
-            logger.error(error_msg)
-            raise
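The cache key above is just a SHA-256 over the model name and a canonical JSON dump of the messages. A standalone check of that computation, using the same scheme outside the deleted class:

    import hashlib
    import json

    def request_hash(model: str, messages: list[dict[str, str]]) -> str:
        # sort_keys=True makes the JSON canonical, so dict ordering cannot change the key.
        message_str = json.dumps(messages, sort_keys=True)
        return hashlib.sha256(f"{model}:{message_str}".encode()).hexdigest()

    msgs = [{"role": "user", "content": "Hello, world!"}]
    print(request_hash("gpt-4o-mini", msgs))  # deterministic 64-hex-char key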
def add(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
|
||||
"""
|
||||
Adds a response to the cache.
|
||||
"""
|
||||
try:
|
||||
request_hash = self._compute_request_hash(model, messages)
|
||||
messages_json = json.dumps(messages, cls=CrewJSONEncoder)
|
||||
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO llm_response_cache
|
||||
(request_hash, model, messages, response)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
request_hash,
|
||||
model,
|
||||
messages_json,
|
||||
response,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = f"Failed to add response to cache: {e}"
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
|
||||
color="red",
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error when adding response: {e}"
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
|
||||
color="red",
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
def get(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
|
||||
"""
|
||||
Retrieves a response from the cache based on the model and messages.
|
||||
Returns None if not found.
|
||||
"""
|
||||
try:
|
||||
request_hash = self._compute_request_hash(model, messages)
|
||||
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT response
|
||||
FROM llm_response_cache
|
||||
WHERE request_hash = ?
|
||||
""",
|
||||
(request_hash,),
|
||||
)
|
||||
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
except sqlite3.Error as e:
|
||||
error_msg = f"Failed to retrieve response from cache: {e}"
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
|
||||
color="red",
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error when retrieving response: {e}"
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
|
||||
color="red",
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return None
|
||||
|
||||
def delete_all(self) -> None:
|
||||
"""
|
||||
Deletes all records from the llm_response_cache table.
|
||||
"""
|
||||
try:
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM llm_response_cache")
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = f"Failed to clear cache: {e}"
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
|
||||
color="red",
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
def cleanup_expired_cache(self, max_age_days: int = 7) -> None:
|
||||
"""
|
||||
Removes cache entries older than the specified number of days.
|
||||
|
||||
This method helps maintain the cache size and ensures that only recent
|
||||
responses are kept, which is important for keeping the cache relevant
|
||||
and preventing it from growing too large over time.
|
||||
|
||||
Args:
|
||||
max_age_days: Maximum age of cache entries in days. Defaults to 7.
|
||||
If set to 0, all entries will be deleted.
|
||||
Must be a non-negative integer.
|
||||
|
||||
Raises:
|
||||
ValueError: If max_age_days is not a non-negative integer.
|
||||
"""
|
||||
if not isinstance(max_age_days, int) or max_age_days < 0:
|
||||
error_msg = "max_age_days must be a non-negative integer"
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
|
||||
color="red",
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
try:
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
if max_age_days <= 0:
|
||||
cursor.execute("DELETE FROM llm_response_cache")
|
||||
deleted_count = cursor.rowcount
|
||||
logger.info("Deleting all cache entries (max_age_days <= 0)")
|
||||
else:
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM llm_response_cache
|
||||
WHERE timestamp < datetime('now', ? || ' days')
|
||||
""",
|
||||
(f"-{max_age_days}",)
|
||||
)
|
||||
deleted_count = cursor.rowcount
|
||||
|
||||
conn.commit()
|
||||
|
||||
if deleted_count > 0:
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE: Removed {deleted_count} expired cache entries",
|
||||
color="green",
|
||||
)
|
||||
logger.info(f"Removed {deleted_count} expired cache entries")
|
||||
|
||||
except sqlite3.Error as e:
|
||||
error_msg = f"Failed to cleanup expired cache: {e}"
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
|
||||
color="red",
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns statistics about the cache.
|
||||
|
||||
Returns:
|
||||
A dictionary containing cache statistics.
|
||||
"""
|
||||
try:
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT COUNT(*) FROM llm_response_cache")
|
||||
total_count = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute("SELECT model, COUNT(*) FROM llm_response_cache GROUP BY model")
|
||||
model_counts = {model: count for model, count in cursor.fetchall()}
|
||||
|
||||
cursor.execute("SELECT MIN(timestamp), MAX(timestamp) FROM llm_response_cache")
|
||||
oldest, newest = cursor.fetchone()
|
||||
|
||||
return {
|
||||
"total_entries": total_count,
|
||||
"entries_by_model": model_counts,
|
||||
"oldest_entry": oldest,
|
||||
"newest_entry": newest,
|
||||
}
|
||||
|
||||
except sqlite3.Error as e:
|
||||
error_msg = f"Failed to get cache stats: {e}"
|
||||
self._printer.print(
|
||||
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
|
||||
color="red",
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return {"error": str(e)}
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""
|
||||
Closes all connections when the object is garbage collected.
|
||||
"""
|
||||
self._close_connections()
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import warnings
|
||||
@@ -15,8 +14,6 @@ from crewai.telemetry.constants import (
|
||||
CREWAI_TELEMETRY_SERVICE_NAME,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_warnings():
|
||||
@@ -31,10 +28,7 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
||||
)
|
||||
from opentelemetry.sdk.resources import SERVICE_NAME, Resource # noqa: E402
|
||||
from opentelemetry.sdk.trace import TracerProvider # noqa: E402
|
||||
from opentelemetry.sdk.trace.export import ( # noqa: E402
|
||||
BatchSpanProcessor,
|
||||
SpanExportResult,
|
||||
)
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor # noqa: E402
|
||||
from opentelemetry.trace import Span, Status, StatusCode # noqa: E402
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -42,15 +36,6 @@ if TYPE_CHECKING:
     from crewai.task import Task


-class SafeOTLPSpanExporter(OTLPSpanExporter):
-    def export(self, spans) -> SpanExportResult:
-        try:
-            return super().export(spans)
-        except Exception as e:
-            logger.error(e)
-            return SpanExportResult.FAILURE
-
-
 class Telemetry:
     """A class to handle anonymous telemetry for the crewai package.

@@ -79,7 +64,7 @@ class Telemetry:
         self.provider = TracerProvider(resource=self.resource)

         processor = BatchSpanProcessor(
-            SafeOTLPSpanExporter(
+            OTLPSpanExporter(
                 endpoint=f"{CREWAI_TELEMETRY_BASE_URL}/v1/traces",
                 timeout=30,
             )
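The removed SafeOTLPSpanExporter simply wraps export() so a telemetry outage can never raise into user code. A generic version of that guard for any OpenTelemetry span exporter, assuming the opentelemetry-sdk package is installed:

    import logging
    from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult

    logger = logging.getLogger(__name__)

    class SafeSpanExporter(SpanExporter):
        """Delegate to an inner exporter, but turn any exception into FAILURE."""

        def __init__(self, inner: SpanExporter) -> None:
            self._inner = inner

        def export(self, spans) -> SpanExportResult:
            try:
                return self._inner.export(spans)
            except Exception as exc:
                logger.error(exc)
                return SpanExportResult.FAILURE

        def shutdown(self) -> None:
            self._inner.shutdown()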
@@ -1,156 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMResponseCacheHandler:
|
||||
"""
|
||||
Handler for the LLM response cache storage.
|
||||
Used for record/replay functionality.
|
||||
"""
|
||||
|
||||
def __init__(self, max_cache_age_days: int = 7) -> None:
|
||||
"""
|
||||
Initializes the LLM response cache handler.
|
||||
|
||||
Args:
|
||||
max_cache_age_days: Maximum age of cache entries in days. Defaults to 7.
|
||||
"""
|
||||
self.storage = LLMResponseCacheStorage()
|
||||
self._recording = False
|
||||
self._replaying = False
|
||||
self.max_cache_age_days = max_cache_age_days
|
||||
|
||||
try:
|
||||
self.storage.cleanup_expired_cache(self.max_cache_age_days)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup expired cache on initialization: {e}")
|
||||
|
||||
def start_recording(self) -> None:
|
||||
"""
|
||||
Starts recording LLM responses.
|
||||
"""
|
||||
self._recording = True
|
||||
self._replaying = False
|
||||
logger.info("Started recording LLM responses")
|
||||
|
||||
def start_replaying(self) -> None:
|
||||
"""
|
||||
Starts replaying LLM responses from the cache.
|
||||
"""
|
||||
self._recording = False
|
||||
self._replaying = True
|
||||
logger.info("Started replaying LLM responses from cache")
|
||||
|
||||
try:
|
||||
stats = self.storage.get_cache_stats()
|
||||
logger.info(f"Cache statistics: {stats}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get cache statistics: {e}")
|
||||
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stops recording or replaying.
|
||||
"""
|
||||
was_recording = self._recording
|
||||
was_replaying = self._replaying
|
||||
|
||||
self._recording = False
|
||||
self._replaying = False
|
||||
|
||||
if was_recording:
|
||||
logger.info("Stopped recording LLM responses")
|
||||
if was_replaying:
|
||||
logger.info("Stopped replaying LLM responses")
|
||||
|
||||
def is_recording(self) -> bool:
|
||||
"""
|
||||
Returns whether recording is active.
|
||||
"""
|
||||
return self._recording
|
||||
|
||||
def is_replaying(self) -> bool:
|
||||
"""
|
||||
Returns whether replaying is active.
|
||||
"""
|
||||
return self._replaying
|
||||
|
||||
def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
|
||||
"""
|
||||
Caches an LLM response if recording is active.
|
||||
|
||||
Args:
|
||||
model: The model used for the LLM call.
|
||||
messages: The messages sent to the LLM.
|
||||
response: The response from the LLM.
|
||||
"""
|
||||
if not self._recording:
|
||||
return
|
||||
|
||||
try:
|
||||
self.storage.add(model, messages, response)
|
||||
logger.debug(f"Cached response for model {model}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cache response: {e}")
|
||||
|
||||
def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
|
||||
"""
|
||||
Retrieves a cached LLM response if replaying is active.
|
||||
Returns None if not found or if replaying is not active.
|
||||
|
||||
Args:
|
||||
model: The model used for the LLM call.
|
||||
messages: The messages sent to the LLM.
|
||||
|
||||
Returns:
|
||||
The cached response, or None if not found or if replaying is not active.
|
||||
"""
|
||||
if not self._replaying:
|
||||
return None
|
||||
|
||||
try:
|
||||
response = self.storage.get(model, messages)
|
||||
if response:
|
||||
logger.debug(f"Retrieved cached response for model {model}")
|
||||
else:
|
||||
logger.debug(f"No cached response found for model {model}")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve cached response: {e}")
|
||||
return None
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""
|
||||
Clears the LLM response cache.
|
||||
"""
|
||||
try:
|
||||
self.storage.delete_all()
|
||||
logger.info("Cleared LLM response cache")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clear cache: {e}")
|
||||
|
||||
def cleanup_expired_cache(self) -> None:
|
||||
"""
|
||||
Removes cache entries older than the maximum age.
|
||||
"""
|
||||
try:
|
||||
self.storage.cleanup_expired_cache(self.max_cache_age_days)
|
||||
logger.info(f"Cleaned up expired cache entries (older than {self.max_cache_age_days} days)")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup expired cache: {e}")
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns statistics about the cache.
|
||||
|
||||
Returns:
|
||||
A dictionary containing cache statistics.
|
||||
"""
|
||||
try:
|
||||
return self.storage.get_cache_stats()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cache stats: {e}")
|
||||
return {"error": str(e)}
|
||||
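End to end, the deleted handler is driven in three steps: construct it, flip it into record or replay mode, then hand it to each LLM. A hedged usage sketch against the removed APIs shown in this compare (these imports and methods only exist on the devin/1746... side, and a recording run still needs real LLM credentials):

    from crewai.llm import LLM
    from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler

    handler = LLMResponseCacheHandler(max_cache_age_days=7)
    handler.start_recording()                      # or handler.start_replaying()

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    print(llm.call([{"role": "user", "content": "Hello, world!"}]))

    handler.stop()
    print(handler.get_cache_stats())               # entry counts per model, oldest/newest timestamps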
File diff suppressed because one or more lines are too long
@@ -18,7 +18,6 @@ from crewai.cli.cli import (
|
||||
train,
|
||||
version,
|
||||
)
|
||||
from crewai.crew import Crew
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -56,133 +55,81 @@ def test_train_invalid_string_iterations(train_crew, runner):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_crew():
|
||||
_mock = mock.Mock(spec=Crew, name="test_crew")
|
||||
_mock.name = "test_crew"
|
||||
return _mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_crews(mock_crew):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
|
||||
) as mock_get_crew:
|
||||
yield mock_get_crew
|
||||
|
||||
|
||||
def test_reset_all_memories(mock_get_crews, runner):
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_all_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-a"])
|
||||
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="all")
|
||||
assert (
|
||||
f"[Crew ({crew.name})] Reset memories command has been completed."
|
||||
in result.output
|
||||
)
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="all")
|
||||
assert result.output == "All memories have been reset.\n"
|
||||
|
||||
|
||||
def test_reset_short_term_memories(mock_get_crews, runner):
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_short_term_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-s"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="short")
|
||||
assert (
|
||||
f"[Crew ({crew.name})] Short term memory has been reset." in result.output
|
||||
)
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="short")
|
||||
assert result.output == "Short term memory has been reset.\n"
|
||||
|
||||
|
||||
def test_reset_entity_memories(mock_get_crews, runner):
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_entity_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-e"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="entity")
|
||||
assert f"[Crew ({crew.name})] Entity memory has been reset." in result.output
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="entity")
|
||||
assert result.output == "Entity memory has been reset.\n"
|
||||
|
||||
|
||||
def test_reset_long_term_memories(mock_get_crews, runner):
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_long_term_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-l"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="long")
|
||||
assert f"[Crew ({crew.name})] Long term memory has been reset." in result.output
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="long")
|
||||
assert result.output == "Long term memory has been reset.\n"
|
||||
|
||||
|
||||
def test_reset_kickoff_outputs(mock_get_crews, runner):
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_kickoff_outputs(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-k"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs")
|
||||
assert (
|
||||
f"[Crew ({crew.name})] Latest Kickoff outputs stored has been reset."
|
||||
in result.output
|
||||
)
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs")
|
||||
assert result.output == "Latest Kickoff outputs stored has been reset.\n"
|
||||
|
||||
|
||||
def test_reset_multiple_memory_flags(mock_get_crews, runner):
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_multiple_memory_flags(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-s", "-l"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_has_calls(
|
||||
[mock.call(command_type="long"), mock.call(command_type="short")]
|
||||
)
|
||||
assert (
|
||||
f"[Crew ({crew.name})] Long term memory has been reset.\n"
|
||||
f"[Crew ({crew.name})] Short term memory has been reset.\n" in result.output
|
||||
)
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
# Check that reset_memories was called twice with the correct arguments
|
||||
assert mock_crew.reset_memories.call_count == 2
|
||||
mock_crew.reset_memories.assert_has_calls(
|
||||
[mock.call(command_type="long"), mock.call(command_type="short")]
|
||||
)
|
||||
assert (
|
||||
result.output
|
||||
== "Long term memory has been reset.\nShort term memory has been reset.\n"
|
||||
)
|
||||
|
||||
|
||||
def test_reset_knowledge(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["--knowledge"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="knowledge")
|
||||
assert f"[Crew ({crew.name})] Knowledge has been reset." in result.output
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
|
||||
def test_reset_memory_from_many_crews(mock_get_crews, runner):
|
||||
|
||||
crews = []
|
||||
for crew_id in ["id-1234", "id-5678"]:
|
||||
mock_crew = mock.Mock(spec=Crew)
|
||||
mock_crew.name = None
|
||||
mock_crew.id = crew_id
|
||||
crews.append(mock_crew)
|
||||
|
||||
mock_get_crews.return_value = crews
|
||||
|
||||
# Run the command
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_knowledge(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["--knowledge"])
|
||||
|
||||
call_count = 0
|
||||
for crew in crews:
|
||||
call_count += 1
|
||||
crew.reset_memories.assert_called_once_with(command_type="knowledge")
|
||||
assert f"[Crew ({crew.id})] Knowledge has been reset." in result.output
|
||||
|
||||
assert call_count == 2, "reset_memories should have been called twice"
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="knowledge")
|
||||
assert result.output == "Knowledge has been reset.\n"
|
||||
|
||||
|
||||
def test_reset_no_memory_flags(runner):
|
||||
|
||||
@@ -3,13 +3,12 @@ import tempfile
|
||||
import unittest
|
||||
import unittest.mock
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest import raises
|
||||
|
||||
from crewai.cli.authentication.utils import TokenManager
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
|
||||
|
||||
@@ -24,20 +23,17 @@ def in_temp_dir():
|
||||
os.chdir(original_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_command():
|
||||
TokenManager().save_tokens("test-token", 36000)
|
||||
tool_command = ToolCommand()
|
||||
with patch.object(tool_command, "login"):
|
||||
yield tool_command
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
def test_create_success(mock_subprocess, capsys, tool_command):
|
||||
def test_create_success(mock_subprocess):
|
||||
with in_temp_dir():
|
||||
tool_command.create("test-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Creating custom tool test_tool..." in output
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with (
|
||||
patch.object(tool_command, "login") as mock_login,
|
||||
patch("sys.stdout", new=StringIO()) as fake_out,
|
||||
):
|
||||
tool_command.create("test-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
assert os.path.isdir("test_tool")
|
||||
assert os.path.isfile(os.path.join("test_tool", "README.md"))
|
||||
@@ -51,12 +47,15 @@ def test_create_success(mock_subprocess, capsys, tool_command):
|
||||
content = f.read()
|
||||
assert "class TestTool" in content
|
||||
|
||||
mock_login.assert_called_once()
|
||||
mock_subprocess.assert_called_once_with(["git", "init"], check=True)
|
||||
|
||||
assert "Creating custom tool test_tool..." in output
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command):
|
||||
def test_install_success(mock_get, mock_subprocess_run):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
mock_get_response.json.return_value = {
|
||||
@@ -66,9 +65,11 @@ def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command):
|
||||
mock_get.return_value = mock_get_response
|
||||
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
||||
|
||||
tool_command.install("sample-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Successfully installed sample-tool" in output
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
tool_command.install("sample-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()])
|
||||
mock_subprocess_run.assert_any_call(
|
||||
@@ -85,42 +86,54 @@ def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command):
|
||||
env=unittest.mock.ANY,
|
||||
)
|
||||
|
||||
assert "Successfully installed sample-tool" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_tool_not_found(mock_get, capsys, tool_command):
|
||||
def test_install_tool_not_found(mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 404
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
with raises(SystemExit):
|
||||
tool_command.install("non-existent-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "No tool found with this name" in output
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.install("non-existent-tool")
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("non-existent-tool")
|
||||
assert "No tool found with this name" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_api_error(mock_get, capsys, tool_command):
|
||||
def test_install_api_error(mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 500
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
with raises(SystemExit):
|
||||
tool_command.install("error-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Failed to get tool details" in output
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.install("error-tool")
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("error-tool")
|
||||
assert "Failed to get tool details" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
|
||||
def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command):
|
||||
with raises(SystemExit):
|
||||
def test_publish_when_not_in_sync(mock_is_synced):
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out, raises(SystemExit):
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "Local changes need to be resolved before publishing" in output
|
||||
assert "Local changes need to be resolved before publishing" in fake_out.getvalue()
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@@ -144,13 +157,13 @@ def test_publish_when_not_in_sync_and_force(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True, force=True)
|
||||
|
||||
mock_get_project_name.assert_called_with(require=True)
|
||||
@@ -192,13 +205,13 @@ def test_publish_success(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
mock_get_project_name.assert_called_with(require=True)
|
||||
@@ -238,21 +251,24 @@ def test_publish_failure(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
capsys,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 422
|
||||
mock_publish_response.json.return_value = {"name": ["is already taken"]}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = capsys.readouterr().out
|
||||
assert "Failed to complete operation" in output
|
||||
assert "Name is already taken" in output
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.publish(is_public=True)
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
assert "Failed to complete operation" in output
|
||||
assert "Name is already taken" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@@ -274,8 +290,6 @@ def test_publish_api_error(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
capsys,
|
||||
tool_command,
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
@@ -283,9 +297,14 @@ def test_publish_api_error(
|
||||
mock_response.ok = False
|
||||
mock_publish.return_value = mock_response
|
||||
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = capsys.readouterr().out
|
||||
assert "Request to Enterprise API failed" in output
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.publish(is_public=True)
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
assert "Request to Enterprise API failed" in output
|
||||
|
||||
@@ -17,7 +17,6 @@ from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.flow import Flow, listen, start
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
@@ -2165,6 +2164,7 @@ def test_tools_with_custom_caching():
|
||||
with patch.object(
|
||||
CacheHandler, "add", wraps=crew._cache_handler.add
|
||||
) as add_to_cache:
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
# Check that add_to_cache was called exactly twice
|
||||
@@ -4351,35 +4351,3 @@ def test_crew_copy_with_memory():
|
||||
raise e # Re-raise other validation errors
|
||||
except Exception as e:
|
||||
pytest.fail(f"Copying crew raised an unexpected exception: {e}")
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_outside_flow(researcher, writer):
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
process=Process.sequential,
|
||||
tasks=[
|
||||
Task(description="Task 1", expected_output="output", agent=researcher),
|
||||
Task(description="Task 2", expected_output="output", agent=writer),
|
||||
],
|
||||
)
|
||||
assert crew.parent_flow is None
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_inside_flow(researcher, writer):
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def start(self):
|
||||
return Crew(
|
||||
agents=[researcher, writer],
|
||||
process=Process.sequential,
|
||||
tasks=[
|
||||
Task(
|
||||
description="Task 1", expected_output="output", agent=researcher
|
||||
),
|
||||
Task(description="Task 2", expected_output="output", agent=writer),
|
||||
],
|
||||
)
|
||||
|
||||
flow = MyFlow()
|
||||
result = flow.kickoff()
|
||||
assert result.parent_flow is flow
|
||||
|
||||
@@ -1,155 +0,0 @@
import sqlite3
import threading
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

from crewai.llm import LLM
from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler


@pytest.fixture
def handler():
    handler = LLMResponseCacheHandler()
    handler.storage.add = MagicMock()
    handler.storage.get = MagicMock()
    return handler


def create_mock_response(content):
    """Create a properly structured mock response object for litellm.completion"""
    message = SimpleNamespace(content=content)
    choice = SimpleNamespace(message=message)
    response = SimpleNamespace(choices=[choice])
    return response
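The helper mirrors the attribute path used when consuming litellm/OpenAI-style responses; a quick sanity check of the shape it builds:

resp = create_mock_response("Hello, human!")
# Same access pattern the calling code would use on a real completion object.
assert resp.choices[0].message.content == "Hello, human!"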
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_recording(handler):
    handler.start_recording()

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        mock_completion.return_value = create_mock_response("Hello, human!")

        response = llm.call(messages)

        assert response == "Hello, human!"

        handler.storage.add.assert_called_once_with(
            "gpt-4o-mini", messages, "Hello, human!"
        )
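Reading the assertions above, recording means: call the model as usual, then persist the reply through the handler. A minimal sketch of that flow, assuming a hypothetical is_recording() accessor on the handler (only start_recording and cache_response appear in these tests):

def call_and_record(model, messages, handler, completion_fn):
    # Illustrative only: not the actual LLM.call implementation.
    response = completion_fn(model=model, messages=messages)
    content = response.choices[0].message.content
    if handler is not None and handler.is_recording():  # assumed accessor
        handler.cache_response(model, messages, content)
    return content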
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replaying(handler):
    handler.start_replaying()
    handler.storage.get.return_value = "Cached response"

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        response = llm.call(messages)

        assert response == "Cached response"

        mock_completion.assert_not_called()

        handler.storage.get.assert_called_once_with("gpt-4o-mini", messages)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replay_fallback(handler):
    handler.start_replaying()
    handler.storage.get.return_value = None

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        mock_completion.return_value = create_mock_response("Hello, human!")

        response = llm.call(messages)

        assert response == "Hello, human!"

        mock_completion.assert_called_once()
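Taken together, the two tests above describe a replay path that prefers the cache and only falls back to the network on a miss. A sketch of that logic, again with a hypothetical is_replaying() accessor:

def call_with_replay(model, messages, handler, completion_fn):
    # Illustrative only: a cache hit short-circuits the network call entirely.
    if handler is not None and handler.is_replaying():  # assumed accessor
        cached = handler.get_cached_response(model, messages)
        if cached is not None:
            return cached
    # Miss (or not replaying): fall back to a real completion.
    response = completion_fn(model=model, messages=messages)
    return response.choices[0].message.content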
@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_error_handling():
    """Test that errors during cache operations are handled gracefully."""
    handler = LLMResponseCacheHandler()

    handler.storage.add = MagicMock(side_effect=sqlite3.Error("Mock DB error"))
    handler.storage.get = MagicMock(side_effect=sqlite3.Error("Mock DB error"))

    handler.start_recording()

    handler.cache_response("model", [{"role": "user", "content": "test"}], "response")

    handler.start_replaying()

    assert handler.get_cached_response("model", [{"role": "user", "content": "test"}]) is None
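The expectation here is that a broken cache must never break a run: reads degrade to a miss and writes are silently dropped. A small sketch of that defensive wrapping on the read path, assuming sqlite3.Error is the failure mode worth catching:

import logging
import sqlite3

logger = logging.getLogger(__name__)


def read_cache_safely(storage, model, messages):
    # Illustrative only: any storage error is logged and treated as a miss.
    try:
        return storage.get(model, messages)
    except sqlite3.Error as exc:
        logger.warning("LLM response cache read failed: %s", exc)
        return None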
@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_expiration():
    """Test that cache expiration works correctly."""
    import sqlite3

    conn = sqlite3.connect(":memory:")
    cursor = conn.cursor()
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS llm_response_cache (
            request_hash TEXT PRIMARY KEY,
            model TEXT,
            messages TEXT,
            response TEXT,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        """
    )
    conn.commit()

    storage = LLMResponseCacheStorage(":memory:")

    original_get_connection = storage._get_connection
    storage._get_connection = lambda: conn

    try:
        model = "test-model"
        messages = [{"role": "user", "content": "test"}]
        response = "test response"
        storage.add(model, messages, response)

        assert storage.get(model, messages) == response

        storage.cleanup_expired_cache(max_age_days=0)

        assert storage.get(model, messages) is None
    finally:
        storage._get_connection = original_get_connection
        conn.close()
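Given the schema created above (a timestamp column defaulting to CURRENT_TIMESTAMP), an age-based cleanup can be a single DELETE. A sketch of what cleanup_expired_cache might do, not the actual implementation:

import sqlite3


def cleanup_expired_sketch(conn: sqlite3.Connection, max_age_days: int) -> None:
    # Illustrative only: rows at or beyond the cutoff age are removed.
    conn.execute(
        "DELETE FROM llm_response_cache WHERE timestamp <= datetime('now', ?)",
        (f"-{max_age_days} days",),
    )
    conn.commit()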
@pytest.mark.vcr(filter_headers=["authorization"])
def test_concurrent_cache_access():
    """Test that concurrent cache access works correctly."""
    pytest.skip("SQLite in-memory databases are not shared between threads")

    # storage = LLMResponseCacheStorage(temp_db.name)
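The skip message points at the fix: use a file-backed database, which SQLite can share across threads. A sketch of how the test could be revived, assuming the storage opens a fresh connection per operation:

import tempfile
from concurrent.futures import ThreadPoolExecutor


def exercise_concurrent_access():
    # Illustrative only: a real test would run these assertions under pytest.
    with tempfile.NamedTemporaryFile(suffix=".db") as temp_db:
        storage = LLMResponseCacheStorage(temp_db.name)

        def write_then_read(i):
            messages = [{"role": "user", "content": f"prompt {i}"}]
            storage.add("test-model", messages, f"response {i}")
            return storage.get("test-model", messages)

        with ThreadPoolExecutor(max_workers=4) as pool:
            results = list(pool.map(write_then_read, range(8)))

        assert results == [f"response {i}" for i in range(8)]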
@@ -1,93 +0,0 @@
from unittest.mock import MagicMock, patch

import pytest

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.process import Process
from crewai.task import Task


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_recording_mode():
    agent = Agent(
        role="Test Agent",
        goal="Test the recording functionality",
        backstory="A test agent for recording LLM responses",
    )

    task = Task(
        description="Return a simple response",
        expected_output="A simple response",
        agent=agent,
    )

    crew = Crew(
        agents=[agent],
        tasks=[task],
        process=Process.sequential,
        record_mode=True,
    )

    mock_handler = MagicMock()
    crew._llm_response_cache_handler = mock_handler

    mock_llm = MagicMock()
    agent.llm = mock_llm

    with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
        with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
            crew.kickoff()

    mock_handler.start_recording.assert_called_once()

    mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_replay_mode():
    agent = Agent(
        role="Test Agent",
        goal="Test the replay functionality",
        backstory="A test agent for replaying LLM responses",
    )

    task = Task(
        description="Return a simple response",
        expected_output="A simple response",
        agent=agent,
    )

    crew = Crew(
        agents=[agent],
        tasks=[task],
        process=Process.sequential,
        replay_mode=True,
    )

    mock_handler = MagicMock()
    crew._llm_response_cache_handler = mock_handler

    mock_llm = MagicMock()
    agent.llm = mock_llm

    with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
        with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
            crew.kickoff()

    mock_handler.start_replaying.assert_called_once()

    mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)
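Both tests assert the same wiring: the mode flag decides whether the shared handler starts in record or replay mode, and every agent's LLM is pointed at it. A sketch of that wiring, not the actual Crew.kickoff code:

def wire_response_cache(agents, handler, record_mode=False, replay_mode=False):
    # Illustrative only.
    if record_mode:
        handler.start_recording()
    elif replay_mode:
        handler.start_replaying()
    else:
        return
    for agent in agents:
        # Each agent's LLM shares the single handler instance.
        agent.llm.set_response_cache_handler(handler)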
@pytest.mark.vcr(filter_headers=["authorization"])
def test_record_replay_flags_conflict():
    with pytest.raises(ValueError):
        crew = Crew(
            agents=[],
            tasks=[],
            process=Process.sequential,
            record_mode=True,
            replay_mode=True,
        )
        crew.kickoff()
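One way to make mutually exclusive flags fail this early is a pydantic model validator; crewAI models are pydantic-based, though this exact validator is only a sketch:

from pydantic import BaseModel, model_validator


class RecordReplayFlags(BaseModel):
    record_mode: bool = False
    replay_mode: bool = False

    @model_validator(mode="after")
    def _mutually_exclusive(self):
        if self.record_mode and self.replay_mode:
            raise ValueError("record_mode and replay_mode cannot both be enabled")
        return self


# Constructing RecordReplayFlags(record_mode=True, replay_mode=True) raises
# pydantic's ValidationError, which subclasses ValueError, so the
# pytest.raises(ValueError) check in the test above would still catch it.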
@@ -1,69 +0,0 @@
import os
from unittest.mock import patch

import pytest

from crewai import Agent, Crew, Task
from crewai.telemetry import Telemetry


@pytest.mark.parametrize(
    "env_var,value,expected_ready",
    [
        ("OTEL_SDK_DISABLED", "true", False),
        ("OTEL_SDK_DISABLED", "TRUE", False),
        ("CREWAI_DISABLE_TELEMETRY", "true", False),
        ("CREWAI_DISABLE_TELEMETRY", "TRUE", False),
        ("OTEL_SDK_DISABLED", "false", True),
        ("CREWAI_DISABLE_TELEMETRY", "false", True),
    ],
)
def test_telemetry_environment_variables(env_var, value, expected_ready):
    """Test telemetry state with different environment variable configurations."""
    with patch.dict(os.environ, {env_var: value}):
        with patch("crewai.telemetry.telemetry.TracerProvider"):
            telemetry = Telemetry()
            assert telemetry.ready is expected_ready


def test_telemetry_enabled_by_default():
    """Test that telemetry is enabled by default."""
    with patch.dict(os.environ, {}, clear=True):
        with patch("crewai.telemetry.telemetry.TracerProvider"):
            telemetry = Telemetry()
            assert telemetry.ready is True
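The parametrized cases treat either environment variable, in any letter case, as "true means disabled". A sketch of the kind of check those cases imply, not the Telemetry class itself:

import os


def telemetry_disabled() -> bool:
    # Illustrative only: matches the cases exercised above ("true"/"TRUE" disable,
    # "false" leaves telemetry ready); other truthy spellings are not covered here.
    return any(
        os.environ.get(var, "").lower() == "true"
        for var in ("OTEL_SDK_DISABLED", "CREWAI_DISABLE_TELEMETRY")
    )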
from opentelemetry import trace


@patch("crewai.telemetry.telemetry.logger.error")
@patch(
    "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter.export",
    side_effect=Exception("Test exception"),
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_telemetry_fails_due_connect_timeout(export_mock, logger_mock):
    error = Exception("Test exception")
    export_mock.side_effect = error

    tracer = trace.get_tracer(__name__)
    with tracer.start_as_current_span("test-span"):
        agent = Agent(
            role="agent",
            llm="gpt-4o-mini",
            goal="Just say hi",
            backstory="You are a helpful assistant that just says hi",
        )
        task = Task(
            description="Just say hi",
            expected_output="hi",
            agent=agent,
        )
        crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
        crew.kickoff()

    trace.get_tracer_provider().force_flush()

    export_mock.assert_called_once()
    logger_mock.assert_called_once_with(error)
@@ -1,16 +1,13 @@
import asyncio
from typing import cast
from unittest.mock import Mock

import pytest
from pydantic import BaseModel, Field

from crewai import LLM, Agent
from crewai.flow import Flow, start
from crewai.lite_agent import LiteAgent, LiteAgentOutput
from crewai.tools import BaseTool
from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.agent_events import LiteAgentExecutionStartedEvent
from crewai.utilities.events.tool_usage_events import ToolUsageStartedEvent

@@ -258,60 +255,3 @@ async def test_lite_agent_returns_usage_metrics_async():
    assert "21 million" in result.raw or "37 million" in result.raw
    assert result.usage_metrics is not None
    assert result.usage_metrics["total_tokens"] > 0


class TestFlow(Flow):
    """A test flow that creates and runs an agent."""

    def __init__(self, llm, tools):
        self.llm = llm
        self.tools = tools
        super().__init__()

    @start()
    def start(self):
        agent = Agent(
            role="Test Agent",
            goal="Test Goal",
            backstory="Test Backstory",
            llm=self.llm,
            tools=self.tools,
        )
        return agent.kickoff("Test query")


def verify_agent_parent_flow(result, agent, flow):
    """Verify that both the result and agent have the correct parent flow."""
    assert result.parent_flow is flow
    assert agent is not None
    assert agent.parent_flow is flow


def test_sets_parent_flow_when_inside_flow():
    captured_agent = None

    mock_llm = Mock(spec=LLM)
    mock_llm.call.return_value = "Test response"

    class MyFlow(Flow):
        @start()
        def start(self):
            agent = Agent(
                role="Test Agent",
                goal="Test Goal",
                backstory="Test Backstory",
                llm=mock_llm,
                tools=[WebSearchTool()],
            )
            return agent.kickoff("Test query")

    flow = MyFlow()
    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(LiteAgentExecutionStartedEvent)
        def capture_agent(source, event):
            nonlocal captured_agent
            captured_agent = source

        result = flow.kickoff()
        assert captured_agent.parent_flow is flow
uv.lock (generated)
@@ -835,7 +835,7 @@ requires-dist = [
    { name = "json-repair", specifier = ">=0.25.2" },
    { name = "json5", specifier = ">=0.10.0" },
    { name = "jsonref", specifier = ">=1.1.0" },
    { name = "litellm", specifier = "==1.68.0" },
    { name = "litellm", specifier = "==1.67.1" },
    { name = "mem0ai", marker = "extra == 'mem0'", specifier = ">=0.1.94" },
    { name = "openai", specifier = ">=1.13.3" },
    { name = "openpyxl", specifier = ">=3.1.5" },
@@ -2387,7 +2387,7 @@ wheels = [

[[package]]
name = "litellm"
version = "1.68.0"
version = "1.67.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "aiohttp" },
@@ -2402,9 +2402,9 @@ dependencies = [
    { name = "tiktoken" },
    { name = "tokenizers" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ba/22/138545b646303ca3f4841b69613c697b9d696322a1386083bb70bcbba60b/litellm-1.68.0.tar.gz", hash = "sha256:9fb24643db84dfda339b64bafca505a2eef857477afbc6e98fb56512c24dbbfa", size = 7314051 }
sdist = { url = "https://files.pythonhosted.org/packages/54/a4/bb3e9ae59e5a9857443448de7c04752630dc84cddcbd8cee037c0976f44f/litellm-1.67.1.tar.gz", hash = "sha256:78eab1bd3d759ec13aa4a05864356a4a4725634e78501db609d451bf72150ee7", size = 7242044 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/10/af/1e344bc8aee41445272e677d802b774b1f8b34bdc3bb5697ba30f0fb5d52/litellm-1.68.0-py3-none-any.whl", hash = "sha256:3bca38848b1a5236b11aa6b70afa4393b60880198c939e582273f51a542d4759", size = 7684460 },
    { url = "https://files.pythonhosted.org/packages/88/86/c14d3c24ae13c08296d068e6f79fd4bd17a0a07bddbda94990b87c35d20e/litellm-1.67.1-py3-none-any.whl", hash = "sha256:8fff5b2a16b63bb594b94d6c071ad0f27d3d8cd4348bd5acea2fd40c8e0c11e8", size = 7607266 },
]

[[package]]