Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-27 09:48:30 +00:00)

Compare commits — 12 commits: lg-pytest-… ... devin/1746… (branch names truncated in source)
| Author | SHA1 | Date |
|---|---|---|
| | b69e328752 | |
| | 223683d8bd | |
| | 62de5a7989 | |
| | 5cccf4f7f5 | |
| | dd5f170f45 | |
| | 6e8e066091 | |
| | d5dfd5a1f5 | |
| | dabf02a90d | |
| | 2912c93d77 | |
| | 17474a3a0c | |
| | f89c2bfb7e | |
| | 2902201bfa | |

(Author and date cells were not preserved in the source.)
.github/workflows/tests.yml (vendored, 2 changed lines)
@@ -31,4 +31,4 @@ jobs:
         run: uv sync --dev --all-extras

       - name: Run tests
-        run: uv run pytest tests -vv
+        run: uv run pytest --block-network --timeout=60 -vv
pyproject.toml

@@ -11,7 +11,7 @@ dependencies = [
     # Core Dependencies
     "pydantic>=2.4.2",
     "openai>=1.13.3",
-    "litellm==1.67.1",
+    "litellm==1.68.0",
     "instructor>=1.3.3",
     # Text Processing
     "pdfplumber>=0.11.4",
@@ -85,6 +85,8 @@ dev-dependencies = [
     "pytest-asyncio>=0.23.7",
     "pytest-subprocess>=1.5.2",
     "pytest-recording>=0.13.2",
+    "pytest-randomly>=3.16.0",
+    "pytest-timeout>=2.3.1",
 ]

 [project.scripts]
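These dev-dependency additions line up with the workflow change above: pytest-recording supplies the --block-network flag plus the vcr marker this suite already uses (see the cassettes and test_crew_creation below), and pytest-timeout backs the new --timeout=60. A sketch of a test running under those flags — the body is an illustrative placeholder, not code from this diff:

import pytest


@pytest.mark.vcr(filter_headers=["authorization"])  # marker style used in this suite
def test_recorded_interaction():
    # Under --block-network, only traffic satisfied by a cassette in
    # tests/cassettes/ is allowed; anything else errors instead of hitting
    # the network. --timeout=60 aborts any test that hangs past a minute.
    pass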
src/crewai/cli/cli.py

@@ -201,9 +201,22 @@ def install(context):


 @crewai.command()
-def run():
+@click.option(
+    "--record",
+    is_flag=True,
+    help="Record LLM responses for later replay",
+)
+@click.option(
+    "--replay",
+    is_flag=True,
+    help="Replay from recorded LLM responses without making network calls",
+)
+def run(record: bool = False, replay: bool = False):
     """Run the Crew."""
-    run_crew()
+    if record and replay:
+        raise click.UsageError("Cannot use --record and --replay simultaneously")
+    click.echo("Running the Crew")
+    run_crew(record=record, replay=replay)


 @crewai.command()
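A quick way to exercise the new flags without a full project is click's test runner — a sketch, assuming `run` is importable from crewai.cli.cli the same way `train` and `version` are in the CLI tests below:

from click.testing import CliRunner

from crewai.cli.cli import run

runner = CliRunner()

# Mutually exclusive flags are rejected before run_crew() is ever called.
result = runner.invoke(run, ["--record", "--replay"])
assert result.exit_code != 0
assert "Cannot use --record and --replay simultaneously" in result.output

# In a real project, `crewai run --record` then forwards --record to the
# `uv run run_crew` subprocess (see the run_crew module below).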
src/crewai/cli/reset_memories_command.py

@@ -2,7 +2,7 @@ import subprocess

 import click

-from crewai.cli.utils import get_crew
+from crewai.cli.utils import get_crews


 def reset_memories_command(
@@ -26,35 +26,47 @@ def reset_memories_command(
     """
     try:
-        crew = get_crew()
-        if not crew:
-            raise ValueError("No crew found.")
-        if all:
-            crew.reset_memories(command_type="all")
-            click.echo("All memories have been reset.")
-            return
-
-        if not any([long, short, entity, kickoff_outputs, knowledge]):
+        if not any([long, short, entity, kickoff_outputs, knowledge, all]):
             click.echo(
                 "No memory type specified. Please specify at least one type to reset."
             )
             return

-        if long:
-            crew.reset_memories(command_type="long")
-            click.echo("Long term memory has been reset.")
-        if short:
-            crew.reset_memories(command_type="short")
-            click.echo("Short term memory has been reset.")
-        if entity:
-            crew.reset_memories(command_type="entity")
-            click.echo("Entity memory has been reset.")
-        if kickoff_outputs:
-            crew.reset_memories(command_type="kickoff_outputs")
-            click.echo("Latest Kickoff outputs stored has been reset.")
-        if knowledge:
-            crew.reset_memories(command_type="knowledge")
-            click.echo("Knowledge has been reset.")
+        crews = get_crews()
+        if not crews:
+            raise ValueError("No crew found.")
+        for crew in crews:
+            if all:
+                crew.reset_memories(command_type="all")
+                click.echo(
+                    f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed."
+                )
+                continue
+            if long:
+                crew.reset_memories(command_type="long")
+                click.echo(
+                    f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset."
+                )
+            if short:
+                crew.reset_memories(command_type="short")
+                click.echo(
+                    f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset."
+                )
+            if entity:
+                crew.reset_memories(command_type="entity")
+                click.echo(
+                    f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset."
+                )
+            if kickoff_outputs:
+                crew.reset_memories(command_type="kickoff_outputs")
+                click.echo(
+                    f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset."
+                )
+            if knowledge:
+                crew.reset_memories(command_type="knowledge")
+                click.echo(
+                    f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset."
+                )

     except subprocess.CalledProcessError as e:
         click.echo(f"An error occurred while resetting the memories: {e}", err=True)
src/crewai/cli/run_crew.py (path inferred)

@@ -14,13 +14,17 @@ class CrewType(Enum):
     FLOW = "flow"


-def run_crew() -> None:
+def run_crew(record: bool = False, replay: bool = False) -> None:
     """
     Run the crew or flow by running a command in the UV environment.

     Starting from version 0.103.0, this command can be used to run both
     standard crews and flows. For flows, it detects the type from pyproject.toml
     and automatically runs the appropriate command.
+
+    Args:
+        record (bool, optional): Whether to record LLM responses. Defaults to False.
+        replay (bool, optional): Whether to replay from recorded LLM responses. Defaults to False.
     """
     crewai_version = get_crewai_version()
     min_required_version = "0.71.0"

@@ -44,17 +48,24 @@ def run_crew() -> None:
     click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")

     # Execute the appropriate command
-    execute_command(crew_type)
+    execute_command(crew_type, record, replay)


-def execute_command(crew_type: CrewType) -> None:
+def execute_command(crew_type: CrewType, record: bool = False, replay: bool = False) -> None:
     """
     Execute the appropriate command based on crew type.

     Args:
         crew_type: The type of crew to run
         record: Whether to record LLM responses
         replay: Whether to replay from recorded LLM responses
     """
     command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]

+    if record:
+        command.append("--record")
+    if replay:
+        command.append("--replay")
+
     try:
         subprocess.run(command, capture_output=False, text=True, check=True)
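Worked example of the forwarding above (a sketch of the logic, not additional code from the diff): for a standard crew with record=True the subprocess call becomes `uv run run_crew --record`; for a flow with replay=True it becomes `uv run kickoff --replay`.

command = ["uv", "run", "run_crew"]  # CrewType.CREW branch
command.append("--record")           # record=True
assert command == ["uv", "run", "run_crew", "--record"]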
src/crewai/cli/utils.py

@@ -2,7 +2,8 @@ import os
 import shutil
 import sys
 from functools import reduce
-from typing import Any, Dict, List
+from inspect import isfunction, ismethod
+from typing import Any, Dict, List, get_type_hints

 import click
 import tomli

@@ -10,6 +11,7 @@ from rich.console import Console

 from crewai.cli.constants import ENV_VARS
 from crewai.crew import Crew
+from crewai.flow import Flow

 if sys.version_info >= (3, 11):
     import tomllib

@@ -250,11 +252,11 @@ def write_env_file(folder_path, env_vars):
         file.write(f"{key}={value}\n")


-def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
-    """Get the crew instance from the crew.py file."""
+def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
+    """Get the crew instances from a file."""
+    crew_instances = []
     try:
         import importlib.util
         import os

         for root, _, files in os.walk("."):
             if crew_path in files:

@@ -271,12 +273,10 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
                 spec.loader.exec_module(module)

                 for attr_name in dir(module):
-                    attr = getattr(module, attr_name)
-                    try:
-                        if callable(attr) and hasattr(attr, "crew"):
-                            crew_instance = attr().crew()
-                            return crew_instance
+                    module_attr = getattr(module, attr_name)
+
+                    try:
+                        crew_instances.extend(fetch_crews(module_attr))
                     except Exception as e:
                         print(f"Error processing attribute {attr_name}: {e}")
                         continue

@@ -286,7 +286,6 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
                     import traceback

                     print(f"Traceback: {traceback.format_exc()}")
-
             except (ImportError, AttributeError) as e:
                 if require:
                     console.print(

@@ -300,7 +299,6 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
         if require:
             console.print("No valid Crew instance found in crew.py", style="bold red")
             raise SystemExit
-        return None

     except Exception as e:
         if require:

@@ -308,4 +306,36 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
                 f"Unexpected error while loading crew: {str(e)}", style="bold red"
             )
             raise SystemExit
+    return crew_instances
+
+
+def get_crew_instance(module_attr) -> Crew | None:
+    if (
+        callable(module_attr)
+        and hasattr(module_attr, "is_crew_class")
+        and module_attr.is_crew_class
+    ):
+        return module_attr().crew()
+    if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
+        module_attr
+    ).get("return") is Crew:
+        return module_attr()
+    elif isinstance(module_attr, Crew):
+        return module_attr
+    else:
+        return None
+
+
+def fetch_crews(module_attr) -> list[Crew]:
+    crew_instances: list[Crew] = []
+
+    if crew_instance := get_crew_instance(module_attr):
+        crew_instances.append(crew_instance)
+
+    if isinstance(module_attr, type) and issubclass(module_attr, Flow):
+        instance = module_attr()
+        for attr_name in dir(instance):
+            attr = getattr(instance, attr_name)
+            if crew_instance := get_crew_instance(attr):
+                crew_instances.append(crew_instance)
+    return crew_instances
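To make the discovery contract concrete, here is a hypothetical crew.py — a sketch, not part of the diff — showing the attribute shapes the new helpers pick up:

# crew.py — hypothetical module discovered by get_crews() while walking
# the project tree; fetch_crews() inspects every module attribute.

from crewai import Agent, Crew, Task

writer = Agent(role="Writer", goal="Write one line", backstory="A writer.")
write = Task(description="Write one line.", expected_output="One line.", agent=writer)

# 1) A plain Crew instance (matched by isinstance(module_attr, Crew)).
my_crew = Crew(agents=[writer], tasks=[write])


# 2) A function annotated as returning Crew (matched via get_type_hints).
def build_crew() -> Crew:
    return Crew(agents=[writer], tasks=[write])

# 3) A class carrying an `is_crew_class` marker (e.g. one decorated with
#    @CrewBase — an assumption about how that marker is set) is called as
#    module_attr().crew(); Flow subclasses are instantiated and their
#    attributes scanned for nested crews the same way.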
src/crewai/crew.py

@@ -6,7 +6,17 @@ import warnings
 from concurrent.futures import Future
 from copy import copy as shallow_copy
 from hashlib import md5
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    cast,
+)

 from pydantic import (
     UUID4,

@@ -24,6 +34,7 @@ from crewai.agent import Agent
 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.cache import CacheHandler
 from crewai.crews.crew_output import CrewOutput
+from crewai.flow.flow_trackable import FlowTrackable
 from crewai.knowledge.knowledge import Knowledge
 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.llm import LLM, BaseLLM

@@ -69,7 +80,7 @@ from crewai.utilities.training_handler import CrewTrainingHandler
 warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")


-class Crew(BaseModel):
+class Crew(FlowTrackable, BaseModel):
     """
     Represents a group of agents, defining how they should collaborate and the tasks they should perform.

@@ -233,6 +244,15 @@ class Crew(BaseModel):
         default_factory=SecurityConfig,
         description="Security configuration for the crew, including fingerprinting.",
     )
+    record_mode: bool = Field(
+        default=False,
+        description="Whether to record LLM responses for later replay.",
+    )
+    replay_mode: bool = Field(
+        default=False,
+        description="Whether to replay from recorded LLM responses without making network calls.",
+    )
+    _llm_response_cache_handler: Optional[Any] = PrivateAttr(default=None)

     @field_validator("id", mode="before")
     @classmethod

@@ -304,7 +324,9 @@ class Crew(BaseModel):
         """Initialize private memory attributes."""
         self._external_memory = (
+            # External memory doesn’t support a default value since it was designed to be managed entirely externally
-            self.external_memory.set_crew(self) if self.external_memory else None
+            self.external_memory.set_crew(self)
+            if self.external_memory
+            else None
         )

         self._long_term_memory = self.long_term_memory

@@ -333,6 +355,7 @@ class Crew(BaseModel):
                     embedder=self.embedder,
                     collection_name="crew",
                 )
+                self.knowledge.add_sources()

             except Exception as e:
                 self._logger.log(

@@ -619,6 +642,19 @@ class Crew(BaseModel):
         self._task_output_handler.reset()
         self._logging_color = "bold_purple"

+        if self.record_mode and self.replay_mode:
+            raise ValueError("Cannot use both record_mode and replay_mode at the same time")
+
+        if self.record_mode or self.replay_mode:
+            from crewai.utilities.llm_response_cache_handler import (
+                LLMResponseCacheHandler,
+            )
+
+            self._llm_response_cache_handler = LLMResponseCacheHandler()
+            if self.record_mode:
+                self._llm_response_cache_handler.start_recording()
+            elif self.replay_mode:
+                self._llm_response_cache_handler.start_replaying()
+
         if inputs is not None:
             self._inputs = inputs
             self._interpolate_inputs(inputs)
|
||||
|
||||
if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
|
||||
if self._llm_response_cache_handler:
|
||||
if hasattr(agent, "llm") and agent.llm:
|
||||
agent.llm.set_response_cache_handler(self._llm_response_cache_handler)
|
||||
if hasattr(agent, "function_calling_llm") and agent.function_calling_llm:
|
||||
agent.function_calling_llm.set_response_cache_handler(self._llm_response_cache_handler)
|
||||
|
||||
agent.create_agent_executor()
|
||||
|
||||
@@ -1273,6 +1315,9 @@ class Crew(BaseModel):
|
||||
def _finish_execution(self, final_string_output: str) -> None:
|
||||
if self.max_rpm:
|
||||
self._rpm_controller.stop_rpm_counter()
|
||||
|
||||
if self._llm_response_cache_handler:
|
||||
self._llm_response_cache_handler.stop()
|
||||
|
||||
def calculate_usage_metrics(self) -> UsageMetrics:
|
||||
"""Calculates and returns the usage metrics."""
|
||||
@@ -1369,8 +1414,6 @@ class Crew(BaseModel):
|
||||
else:
|
||||
self._reset_specific_memory(command_type)
|
||||
|
||||
self._logger.log("info", f"{command_type} memory has been reset")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to reset {command_type} memory: {str(e)}"
|
||||
self._logger.log("error", error_msg)
|
||||
@@ -1391,8 +1434,14 @@ class Crew(BaseModel):
|
||||
if system is not None:
|
||||
try:
|
||||
system.reset()
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _reset_specific_memory(self, memory_type: str) -> None:
|
||||
"""Reset a specific memory system.
|
||||
@@ -1421,5 +1470,11 @@ class Crew(BaseModel):
|
||||
|
||||
try:
|
||||
memory_system.reset()
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
|
||||
) from e
|
||||
|
||||
src/crewai/flow/flow_trackable.py (new file, 44 lines)

import inspect
from typing import Optional

from pydantic import BaseModel, Field, InstanceOf, model_validator

from crewai.flow import Flow


class FlowTrackable(BaseModel):
    """Mixin that tracks the Flow instance that instantiated the object, e.g. a
    Flow instance that created a Crew or Agent.

    Automatically finds and stores a reference to the parent Flow instance by
    inspecting the call stack.
    """

    parent_flow: Optional[InstanceOf[Flow]] = Field(
        default=None,
        description="The parent flow of the instance, if it was created inside a flow.",
    )

    @model_validator(mode="after")
    def _set_parent_flow(self, max_depth: int = 5) -> "FlowTrackable":
        frame = inspect.currentframe()

        try:
            if frame is None:
                return self

            frame = frame.f_back
            for _ in range(max_depth):
                if frame is None:
                    break

                candidate = frame.f_locals.get("self")
                if isinstance(candidate, Flow):
                    self.parent_flow = candidate
                    break

                frame = frame.f_back
        finally:
            del frame

        return self
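A sketch of what the mixin buys: a Crew created inside a Flow method now knows its parent. Agent and task details are illustrative; Flow and start come from crewai.flow as imported elsewhere in this diff:

from crewai import Agent, Crew, Task
from crewai.flow import Flow, start


class BuildFlow(Flow):
    @start()
    def build(self):
        agent = Agent(role="Writer", goal="Write", backstory="A writer.")
        task = Task(description="Write one line.", expected_output="One line.", agent=agent)
        crew = Crew(agents=[agent], tasks=[task])
        # The model validator walked up to 5 stack frames, found this Flow
        # bound as `self` in the calling frame, and stored the reference.
        assert crew.parent_flow is self
        return "ok"


BuildFlow().kickoff()  # runs build() without invoking any LLM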
src/crewai/knowledge/knowledge.py

@@ -41,7 +41,6 @@ class Knowledge(BaseModel):
             )
         self.sources = sources
         self.storage.initialize_knowledge_storage()
-        self._add_sources()

     def query(
         self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35

@@ -63,7 +62,7 @@ class Knowledge(BaseModel):
         )
         return results

-    def _add_sources(self):
+    def add_sources(self):
         try:
             for source in self.sources:
                 source.storage = self.storage
src/crewai/lite_agent.py (path inferred)

@@ -13,6 +13,7 @@ from crewai.agents.parser import (
     AgentFinish,
     OutputParserException,
 )
+from crewai.flow.flow_trackable import FlowTrackable
 from crewai.llm import LLM
 from crewai.tools.base_tool import BaseTool
 from crewai.tools.structured_tool import CrewStructuredTool

@@ -80,7 +81,7 @@ class LiteAgentOutput(BaseModel):
         return self.raw


-class LiteAgent(BaseModel):
+class LiteAgent(FlowTrackable, BaseModel):
     """
     A lightweight agent that can process messages and use tools.

@@ -162,7 +163,7 @@ class LiteAgent(BaseModel):
     _messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
     _iterations: int = PrivateAttr(default=0)
     _printer: Printer = PrivateAttr(default_factory=Printer)

     @model_validator(mode="after")
     def setup_llm(self):
         """Set up the LLM and other components after initialization."""
src/crewai/llm.py

@@ -296,6 +296,7 @@ class LLM(BaseLLM):
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream
+        self._response_cache_handler = None

         litellm.drop_params = True

@@ -869,25 +870,43 @@
             for message in messages:
                 if message.get("role") == "system":
                     message["role"] = "assistant"

+        if self._response_cache_handler and self._response_cache_handler.is_replaying():
+            cached_response = self._response_cache_handler.get_cached_response(
+                self.model, messages
+            )
+            if cached_response:
+                # Emit completion event for the cached response
+                self._handle_emit_call_events(cached_response, LLMCallType.LLM_CALL)
+                return cached_response
+
-        # --- 5) Set up callbacks if provided
+        # --- 6) Set up callbacks if provided
         with suppress_warnings():
             if callbacks and len(callbacks) > 0:
                 self.set_callbacks(callbacks)

         try:
-            # --- 6) Prepare parameters for the completion call
+            # --- 7) Prepare parameters for the completion call
             params = self._prepare_completion_params(messages, tools)

-            # --- 7) Make the completion call and handle response
+            # --- 8) Make the completion call and handle response
             if self.stream:
-                return self._handle_streaming_response(
+                response = self._handle_streaming_response(
                     params, callbacks, available_functions
                 )
             else:
-                return self._handle_non_streaming_response(
+                response = self._handle_non_streaming_response(
                     params, callbacks, available_functions
                 )

+            if (self._response_cache_handler and
+                self._response_cache_handler.is_recording() and
+                isinstance(response, str)):
+                self._response_cache_handler.cache_response(
+                    self.model, messages, response
+                )
+
+            return response
+
         except LLMContextLengthExceededException:
             # Re-raise LLMContextLengthExceededException as it should be handled

@@ -1107,3 +1126,18 @@
         litellm.success_callback = success_callbacks
         litellm.failure_callback = failure_callbacks

+    def set_response_cache_handler(self, handler):
+        """
+        Sets the response cache handler for record/replay functionality.
+
+        Args:
+            handler: An instance of LLMResponseCacheHandler.
+        """
+        self._response_cache_handler = handler
+
+    def clear_response_cache_handler(self):
+        """
+        Clears the response cache handler.
+        """
+        self._response_cache_handler = None
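For completeness, a sketch of wiring the handler into an LLM by hand, mirroring what Crew.kickoff() does per agent — the model name is illustrative:

from crewai.llm import LLM
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler

llm = LLM(model="gpt-4o-mini")
handler = LLMResponseCacheHandler()

handler.start_recording()
llm.set_response_cache_handler(handler)
# llm.call(...) now persists each (model, messages) -> response string.

handler.start_replaying()
# llm.call(...) now short-circuits to the cache before any network setup.

handler.stop()
llm.clear_response_cache_handler()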
src/crewai/memory/storage/llm_response_cache_storage.py (new file, 314 lines)

import hashlib
import json
import logging
import sqlite3
import threading
from typing import Any, Dict, List, Optional

from crewai.utilities import Printer
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
from crewai.utilities.paths import db_storage_path

logger = logging.getLogger(__name__)


class LLMResponseCacheStorage:
    """
    SQLite storage for caching LLM responses.
    Used for offline record/replay functionality.
    """

    def __init__(
        self, db_path: str = f"{db_storage_path()}/llm_response_cache.db"
    ) -> None:
        self.db_path = db_path
        self._printer: Printer = Printer()
        self._connection_pool: Dict[int, sqlite3.Connection] = {}
        self._initialize_db()

    def _get_connection(self) -> sqlite3.Connection:
        """
        Gets a connection from the connection pool or creates a new one.
        Uses thread-local storage to ensure thread safety.
        """
        thread_id = threading.get_ident()
        if thread_id not in self._connection_pool:
            try:
                conn = sqlite3.connect(self.db_path)
                conn.execute("PRAGMA foreign_keys = ON")
                conn.execute("PRAGMA journal_mode = WAL")
                self._connection_pool[thread_id] = conn
            except sqlite3.Error as e:
                error_msg = f"Failed to create SQLite connection: {e}"
                self._printer.print(
                    content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                    color="red",
                )
                logger.error(error_msg)
                raise
        return self._connection_pool[thread_id]

    def _close_connections(self) -> None:
        """
        Closes all connections in the connection pool.
        """
        for thread_id, conn in list(self._connection_pool.items()):
            try:
                conn.close()
                del self._connection_pool[thread_id]
            except sqlite3.Error as e:
                error_msg = f"Failed to close SQLite connection: {e}"
                self._printer.print(
                    content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                    color="red",
                )
                logger.error(error_msg)

    def _initialize_db(self) -> None:
        """
        Initializes the SQLite database and creates the llm_response_cache table
        """
        try:
            conn = self._get_connection()
            cursor = conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS llm_response_cache (
                    request_hash TEXT PRIMARY KEY,
                    model TEXT,
                    messages TEXT,
                    response TEXT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                )
                """
            )
            conn.commit()
        except sqlite3.Error as e:
            error_msg = f"Failed to initialize database: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise

    def _compute_request_hash(self, model: str, messages: List[Dict[str, str]]) -> str:
        """
        Computes a hash for the request based on the model and messages.
        This hash is used as the key for caching.

        Sensitive information like API keys should not be included in the hash.
        """
        try:
            message_str = json.dumps(messages, sort_keys=True)
            request_hash = hashlib.sha256(f"{model}:{message_str}".encode()).hexdigest()
            return request_hash
        except Exception as e:
            error_msg = f"Failed to compute request hash: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise

    def add(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
        """
        Adds a response to the cache.
        """
        try:
            request_hash = self._compute_request_hash(model, messages)
            messages_json = json.dumps(messages, cls=CrewJSONEncoder)

            conn = self._get_connection()
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT OR REPLACE INTO llm_response_cache
                (request_hash, model, messages, response)
                VALUES (?, ?, ?, ?)
                """,
                (
                    request_hash,
                    model,
                    messages_json,
                    response,
                ),
            )
            conn.commit()
        except sqlite3.Error as e:
            error_msg = f"Failed to add response to cache: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise
        except Exception as e:
            error_msg = f"Unexpected error when adding response: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise

    def get(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
        """
        Retrieves a response from the cache based on the model and messages.
        Returns None if not found.
        """
        try:
            request_hash = self._compute_request_hash(model, messages)

            conn = self._get_connection()
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT response
                FROM llm_response_cache
                WHERE request_hash = ?
                """,
                (request_hash,),
            )

            result = cursor.fetchone()
            return result[0] if result else None

        except sqlite3.Error as e:
            error_msg = f"Failed to retrieve response from cache: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            return None
        except Exception as e:
            error_msg = f"Unexpected error when retrieving response: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            return None

    def delete_all(self) -> None:
        """
        Deletes all records from the llm_response_cache table.
        """
        try:
            conn = self._get_connection()
            cursor = conn.cursor()
            cursor.execute("DELETE FROM llm_response_cache")
            conn.commit()
        except sqlite3.Error as e:
            error_msg = f"Failed to clear cache: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise

    def cleanup_expired_cache(self, max_age_days: int = 7) -> None:
        """
        Removes cache entries older than the specified number of days.

        This method helps maintain the cache size and ensures that only recent
        responses are kept, which is important for keeping the cache relevant
        and preventing it from growing too large over time.

        Args:
            max_age_days: Maximum age of cache entries in days. Defaults to 7.
                If set to 0, all entries will be deleted.
                Must be a non-negative integer.

        Raises:
            ValueError: If max_age_days is not a non-negative integer.
        """
        if not isinstance(max_age_days, int) or max_age_days < 0:
            error_msg = "max_age_days must be a non-negative integer"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        try:
            conn = self._get_connection()
            cursor = conn.cursor()

            if max_age_days <= 0:
                cursor.execute("DELETE FROM llm_response_cache")
                deleted_count = cursor.rowcount
                logger.info("Deleting all cache entries (max_age_days <= 0)")
            else:
                cursor.execute(
                    """
                    DELETE FROM llm_response_cache
                    WHERE timestamp < datetime('now', ? || ' days')
                    """,
                    (f"-{max_age_days}",)
                )
                deleted_count = cursor.rowcount

            conn.commit()

            if deleted_count > 0:
                self._printer.print(
                    content=f"LLM RESPONSE CACHE: Removed {deleted_count} expired cache entries",
                    color="green",
                )
                logger.info(f"Removed {deleted_count} expired cache entries")

        except sqlite3.Error as e:
            error_msg = f"Failed to cleanup expired cache: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Returns statistics about the cache.

        Returns:
            A dictionary containing cache statistics.
        """
        try:
            conn = self._get_connection()
            cursor = conn.cursor()

            cursor.execute("SELECT COUNT(*) FROM llm_response_cache")
            total_count = cursor.fetchone()[0]

            cursor.execute("SELECT model, COUNT(*) FROM llm_response_cache GROUP BY model")
            model_counts = {model: count for model, count in cursor.fetchall()}

            cursor.execute("SELECT MIN(timestamp), MAX(timestamp) FROM llm_response_cache")
            oldest, newest = cursor.fetchone()

            return {
                "total_entries": total_count,
                "entries_by_model": model_counts,
                "oldest_entry": oldest,
                "newest_entry": newest,
            }

        except sqlite3.Error as e:
            error_msg = f"Failed to get cache stats: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            return {"error": str(e)}

    def __del__(self) -> None:
        """
        Closes all connections when the object is garbage collected.
        """
        self._close_connections()
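A round-trip sketch against the storage API above (the temp db path and model names are illustrative):

from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage

storage = LLMResponseCacheStorage(db_path="/tmp/llm_cache_demo.db")
messages = [{"role": "user", "content": "What is 2 + 2?"}]

# Keyed by sha256(f"{model}:{json.dumps(messages, sort_keys=True)}").
storage.add("gpt-4o-mini", messages, "4")
assert storage.get("gpt-4o-mini", messages) == "4"   # same model + messages: hit
assert storage.get("gpt-4o", messages) is None       # different model: miss

print(storage.get_cache_stats())   # totals, per-model counts, oldest/newest
storage.cleanup_expired_cache(max_age_days=0)        # 0 deletes every entry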
src/crewai/telemetry/telemetry.py (path inferred)

@@ -2,6 +2,7 @@ from __future__ import annotations

 import asyncio
 import json
+import logging
 import os
 import platform
 import warnings

@@ -14,6 +15,8 @@ from crewai.telemetry.constants import (
     CREWAI_TELEMETRY_SERVICE_NAME,
 )

+logger = logging.getLogger(__name__)
+

 @contextmanager
 def suppress_warnings():

@@ -28,7 +31,10 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
 )
 from opentelemetry.sdk.resources import SERVICE_NAME, Resource  # noqa: E402
 from opentelemetry.sdk.trace import TracerProvider  # noqa: E402
-from opentelemetry.sdk.trace.export import BatchSpanProcessor  # noqa: E402
+from opentelemetry.sdk.trace.export import (  # noqa: E402
+    BatchSpanProcessor,
+    SpanExportResult,
+)
 from opentelemetry.trace import Span, Status, StatusCode  # noqa: E402

 if TYPE_CHECKING:

@@ -36,6 +42,15 @@ if TYPE_CHECKING:
     from crewai.task import Task


+class SafeOTLPSpanExporter(OTLPSpanExporter):
+    def export(self, spans) -> SpanExportResult:
+        try:
+            return super().export(spans)
+        except Exception as e:
+            logger.error(e)
+            return SpanExportResult.FAILURE
+
+
 class Telemetry:
     """A class to handle anonymous telemetry for the crewai package.

@@ -64,7 +79,7 @@ class Telemetry:
             self.provider = TracerProvider(resource=self.resource)

             processor = BatchSpanProcessor(
-                OTLPSpanExporter(
+                SafeOTLPSpanExporter(
                     endpoint=f"{CREWAI_TELEMETRY_BASE_URL}/v1/traces",
                     timeout=30,
                 )
src/crewai/utilities/events/crewai_event_bus.py (path inferred)

@@ -70,7 +70,12 @@ class CrewAIEventsBus:
         for event_type, handlers in self._handlers.items():
             if isinstance(event, event_type):
                 for handler in handlers:
-                    handler(source, event)
+                    try:
+                        handler(source, event)
+                    except Exception as e:
+                        print(
+                            f"[EventBus Error] Handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}"
+                        )

         self._signal.send(source, event=event)
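A sketch of the behavior this change guarantees — a raising handler no longer starves later handlers. The import paths here are assumptions based on the existing events API:

from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.crew_events import CrewKickoffStartedEvent


@crewai_event_bus.on(CrewKickoffStartedEvent)
def flaky_handler(source, event):
    raise RuntimeError("boom")  # previously this propagated out of emit


@crewai_event_bus.on(CrewKickoffStartedEvent)
def audit_handler(source, event):
    print("crew started")  # still runs; the failure above is only printed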
src/crewai/utilities/llm_response_cache_handler.py (new file, 156 lines)

import logging
from typing import Any, Dict, List, Optional

from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage

logger = logging.getLogger(__name__)


class LLMResponseCacheHandler:
    """
    Handler for the LLM response cache storage.
    Used for record/replay functionality.
    """

    def __init__(self, max_cache_age_days: int = 7) -> None:
        """
        Initializes the LLM response cache handler.

        Args:
            max_cache_age_days: Maximum age of cache entries in days. Defaults to 7.
        """
        self.storage = LLMResponseCacheStorage()
        self._recording = False
        self._replaying = False
        self.max_cache_age_days = max_cache_age_days

        try:
            self.storage.cleanup_expired_cache(self.max_cache_age_days)
        except Exception as e:
            logger.warning(f"Failed to cleanup expired cache on initialization: {e}")

    def start_recording(self) -> None:
        """
        Starts recording LLM responses.
        """
        self._recording = True
        self._replaying = False
        logger.info("Started recording LLM responses")

    def start_replaying(self) -> None:
        """
        Starts replaying LLM responses from the cache.
        """
        self._recording = False
        self._replaying = True
        logger.info("Started replaying LLM responses from cache")

        try:
            stats = self.storage.get_cache_stats()
            logger.info(f"Cache statistics: {stats}")
        except Exception as e:
            logger.warning(f"Failed to get cache statistics: {e}")

    def stop(self) -> None:
        """
        Stops recording or replaying.
        """
        was_recording = self._recording
        was_replaying = self._replaying

        self._recording = False
        self._replaying = False

        if was_recording:
            logger.info("Stopped recording LLM responses")
        if was_replaying:
            logger.info("Stopped replaying LLM responses")

    def is_recording(self) -> bool:
        """
        Returns whether recording is active.
        """
        return self._recording

    def is_replaying(self) -> bool:
        """
        Returns whether replaying is active.
        """
        return self._replaying

    def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
        """
        Caches an LLM response if recording is active.

        Args:
            model: The model used for the LLM call.
            messages: The messages sent to the LLM.
            response: The response from the LLM.
        """
        if not self._recording:
            return

        try:
            self.storage.add(model, messages, response)
            logger.debug(f"Cached response for model {model}")
        except Exception as e:
            logger.error(f"Failed to cache response: {e}")

    def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
        """
        Retrieves a cached LLM response if replaying is active.
        Returns None if not found or if replaying is not active.

        Args:
            model: The model used for the LLM call.
            messages: The messages sent to the LLM.

        Returns:
            The cached response, or None if not found or if replaying is not active.
        """
        if not self._replaying:
            return None

        try:
            response = self.storage.get(model, messages)
            if response:
                logger.debug(f"Retrieved cached response for model {model}")
            else:
                logger.debug(f"No cached response found for model {model}")
            return response
        except Exception as e:
            logger.error(f"Failed to retrieve cached response: {e}")
            return None

    def clear_cache(self) -> None:
        """
        Clears the LLM response cache.
        """
        try:
            self.storage.delete_all()
            logger.info("Cleared LLM response cache")
        except Exception as e:
            logger.error(f"Failed to clear cache: {e}")

    def cleanup_expired_cache(self) -> None:
        """
        Removes cache entries older than the maximum age.
        """
        try:
            self.storage.cleanup_expired_cache(self.max_cache_age_days)
            logger.info(f"Cleaned up expired cache entries (older than {self.max_cache_age_days} days)")
        except Exception as e:
            logger.error(f"Failed to cleanup expired cache: {e}")

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Returns statistics about the cache.

        Returns:
            A dictionary containing cache statistics.
        """
        try:
            return self.storage.get_cache_stats()
        except Exception as e:
            logger.error(f"Failed to get cache stats: {e}")
            return {"error": str(e)}
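The handler's gating semantics, sketched end to end (model and messages are illustrative): cache_response() is a no-op unless recording is active, and get_cached_response() returns None unless replaying is active:

from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler

handler = LLMResponseCacheHandler(max_cache_age_days=7)
msgs = [{"role": "user", "content": "hi"}]

handler.cache_response("gpt-4o-mini", msgs, "hello")   # ignored: not recording
handler.start_recording()
handler.cache_response("gpt-4o-mini", msgs, "hello")   # stored in SQLite

handler.start_replaying()                              # also turns recording off
assert handler.get_cached_response("gpt-4o-mini", msgs) == "hello"

handler.stop()
assert handler.get_cached_response("gpt-4o-mini", msgs) is None  # replay inactive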
tests/cassettes/test_docling_source.yaml (new file, 1899 lines — diff suppressed because it is too large)
tests/cassettes/test_multiple_docling_sources.yaml (new file, 3321 lines — diff suppressed because it is too large)
tests/cassettes/test_telemetry_fails_due_connect_timeout.yaml (new file, 221 lines — diff suppressed because one or more lines are too long)
(CLI tests)

@@ -18,6 +18,7 @@ from crewai.cli.cli import (
     train,
     version,
 )
+from crewai.crew import Crew


 @pytest.fixture

@@ -55,81 +56,133 @@ def test_train_invalid_string_iterations(train_crew, runner):
     )


-@mock.patch("crewai.cli.reset_memories_command.get_crew")
-def test_reset_all_memories(mock_get_crew, runner):
-    mock_crew = mock.Mock()
-    mock_get_crew.return_value = mock_crew
+@pytest.fixture
+def mock_crew():
+    _mock = mock.Mock(spec=Crew, name="test_crew")
+    _mock.name = "test_crew"
+    return _mock
+
+
+@pytest.fixture
+def mock_get_crews(mock_crew):
+    with mock.patch(
+        "crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
+    ) as mock_get_crew:
+        yield mock_get_crew
+
+
+def test_reset_all_memories(mock_get_crews, runner):
     result = runner.invoke(reset_memories, ["-a"])

-    mock_crew.reset_memories.assert_called_once_with(command_type="all")
-    assert result.output == "All memories have been reset.\n"
+    call_count = 0
+    for crew in mock_get_crews.return_value:
+        crew.reset_memories.assert_called_once_with(command_type="all")
+        assert (
+            f"[Crew ({crew.name})] Reset memories command has been completed."
+            in result.output
+        )
+        call_count += 1
+
+    assert call_count == 1, "reset_memories should have been called once"


-@mock.patch("crewai.cli.reset_memories_command.get_crew")
-def test_reset_short_term_memories(mock_get_crew, runner):
-    mock_crew = mock.Mock()
-    mock_get_crew.return_value = mock_crew
+def test_reset_short_term_memories(mock_get_crews, runner):
     result = runner.invoke(reset_memories, ["-s"])
+    call_count = 0
+    for crew in mock_get_crews.return_value:
+        crew.reset_memories.assert_called_once_with(command_type="short")
+        assert (
+            f"[Crew ({crew.name})] Short term memory has been reset." in result.output
+        )
+        call_count += 1

-    mock_crew.reset_memories.assert_called_once_with(command_type="short")
-    assert result.output == "Short term memory has been reset.\n"
+    assert call_count == 1, "reset_memories should have been called once"


-@mock.patch("crewai.cli.reset_memories_command.get_crew")
-def test_reset_entity_memories(mock_get_crew, runner):
-    mock_crew = mock.Mock()
-    mock_get_crew.return_value = mock_crew
+def test_reset_entity_memories(mock_get_crews, runner):
     result = runner.invoke(reset_memories, ["-e"])
+    call_count = 0
+    for crew in mock_get_crews.return_value:
+        crew.reset_memories.assert_called_once_with(command_type="entity")
+        assert f"[Crew ({crew.name})] Entity memory has been reset." in result.output
+        call_count += 1

-    mock_crew.reset_memories.assert_called_once_with(command_type="entity")
-    assert result.output == "Entity memory has been reset.\n"
+    assert call_count == 1, "reset_memories should have been called once"


-@mock.patch("crewai.cli.reset_memories_command.get_crew")
-def test_reset_long_term_memories(mock_get_crew, runner):
-    mock_crew = mock.Mock()
-    mock_get_crew.return_value = mock_crew
+def test_reset_long_term_memories(mock_get_crews, runner):
     result = runner.invoke(reset_memories, ["-l"])
+    call_count = 0
+    for crew in mock_get_crews.return_value:
+        crew.reset_memories.assert_called_once_with(command_type="long")
+        assert f"[Crew ({crew.name})] Long term memory has been reset." in result.output
+        call_count += 1

-    mock_crew.reset_memories.assert_called_once_with(command_type="long")
-    assert result.output == "Long term memory has been reset.\n"
+    assert call_count == 1, "reset_memories should have been called once"


-@mock.patch("crewai.cli.reset_memories_command.get_crew")
-def test_reset_kickoff_outputs(mock_get_crew, runner):
-    mock_crew = mock.Mock()
-    mock_get_crew.return_value = mock_crew
+def test_reset_kickoff_outputs(mock_get_crews, runner):
     result = runner.invoke(reset_memories, ["-k"])
+    call_count = 0
+    for crew in mock_get_crews.return_value:
+        crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs")
+        assert (
+            f"[Crew ({crew.name})] Latest Kickoff outputs stored has been reset."
+            in result.output
+        )
+        call_count += 1

-    mock_crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs")
-    assert result.output == "Latest Kickoff outputs stored has been reset.\n"
+    assert call_count == 1, "reset_memories should have been called once"


-@mock.patch("crewai.cli.reset_memories_command.get_crew")
-def test_reset_multiple_memory_flags(mock_get_crew, runner):
-    mock_crew = mock.Mock()
-    mock_get_crew.return_value = mock_crew
+def test_reset_multiple_memory_flags(mock_get_crews, runner):
     result = runner.invoke(reset_memories, ["-s", "-l"])
+    call_count = 0
+    for crew in mock_get_crews.return_value:
+        crew.reset_memories.assert_has_calls(
+            [mock.call(command_type="long"), mock.call(command_type="short")]
+        )
+        assert (
+            f"[Crew ({crew.name})] Long term memory has been reset.\n"
+            f"[Crew ({crew.name})] Short term memory has been reset.\n" in result.output
+        )
+        call_count += 1

-    # Check that reset_memories was called twice with the correct arguments
-    assert mock_crew.reset_memories.call_count == 2
-    mock_crew.reset_memories.assert_has_calls(
-        [mock.call(command_type="long"), mock.call(command_type="short")]
-    )
-    assert (
-        result.output
-        == "Long term memory has been reset.\nShort term memory has been reset.\n"
-    )
+    assert call_count == 1, "reset_memories should have been called once"


-@mock.patch("crewai.cli.reset_memories_command.get_crew")
-def test_reset_knowledge(mock_get_crew, runner):
-    mock_crew = mock.Mock()
-    mock_get_crew.return_value = mock_crew
+def test_reset_knowledge(mock_get_crews, runner):
     result = runner.invoke(reset_memories, ["--knowledge"])
+    call_count = 0
+    for crew in mock_get_crews.return_value:
+        crew.reset_memories.assert_called_once_with(command_type="knowledge")
+        assert f"[Crew ({crew.name})] Knowledge has been reset." in result.output
+        call_count += 1

-    mock_crew.reset_memories.assert_called_once_with(command_type="knowledge")
-    assert result.output == "Knowledge has been reset.\n"
+    assert call_count == 1, "reset_memories should have been called once"
+
+
+def test_reset_memory_from_many_crews(mock_get_crews, runner):
+    crews = []
+    for crew_id in ["id-1234", "id-5678"]:
+        mock_crew = mock.Mock(spec=Crew)
+        mock_crew.name = None
+        mock_crew.id = crew_id
+        crews.append(mock_crew)
+
+    mock_get_crews.return_value = crews
+
+    # Run the command
+    result = runner.invoke(reset_memories, ["--knowledge"])
+
+    call_count = 0
+    for crew in crews:
+        call_count += 1
+        crew.reset_memories.assert_called_once_with(command_type="knowledge")
+        assert f"[Crew ({crew.id})] Knowledge has been reset." in result.output
+
+    assert call_count == 2, "reset_memories should have been called twice"


 def test_reset_no_memory_flags(runner):
(CLI tool command tests)

@@ -3,12 +3,13 @@ import tempfile
 import unittest
 import unittest.mock
 from contextlib import contextmanager
-from io import StringIO
 from unittest import mock
 from unittest.mock import MagicMock, patch

+import pytest
 from pytest import raises

+from crewai.cli.authentication.utils import TokenManager
 from crewai.cli.tools.main import ToolCommand

@@ -23,17 +24,20 @@ def in_temp_dir():
     os.chdir(original_dir)


-@patch("crewai.cli.tools.main.subprocess.run")
-def test_create_success(mock_subprocess):
-    with in_temp_dir():
-        tool_command = ToolCommand()
+@pytest.fixture
+def tool_command():
+    TokenManager().save_tokens("test-token", 36000)
+    tool_command = ToolCommand()
+    with patch.object(tool_command, "login"):
+        yield tool_command

-        with (
-            patch.object(tool_command, "login") as mock_login,
-            patch("sys.stdout", new=StringIO()) as fake_out,
-        ):
-            tool_command.create("test-tool")
-            output = fake_out.getvalue()

+@patch("crewai.cli.tools.main.subprocess.run")
+def test_create_success(mock_subprocess, capsys, tool_command):
+    with in_temp_dir():
+        tool_command.create("test-tool")
+        output = capsys.readouterr().out
+        assert "Creating custom tool test_tool..." in output

         assert os.path.isdir("test_tool")
         assert os.path.isfile(os.path.join("test_tool", "README.md"))

@@ -47,15 +51,12 @@ def test_create_success(mock_subprocess):
             content = f.read()
             assert "class TestTool" in content

-        mock_login.assert_called_once()
         mock_subprocess.assert_called_once_with(["git", "init"], check=True)

-        assert "Creating custom tool test_tool..." in output
-

 @patch("crewai.cli.tools.main.subprocess.run")
 @patch("crewai.cli.plus_api.PlusAPI.get_tool")
-def test_install_success(mock_get, mock_subprocess_run):
+def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command):
     mock_get_response = MagicMock()
     mock_get_response.status_code = 200
     mock_get_response.json.return_value = {

@@ -65,11 +66,9 @@ def test_install_success(mock_get, mock_subprocess_run):
     mock_get.return_value = mock_get_response
     mock_subprocess_run.return_value = MagicMock(stderr=None)

-    tool_command = ToolCommand()
-
-    with patch("sys.stdout", new=StringIO()) as fake_out:
-        tool_command.install("sample-tool")
-        output = fake_out.getvalue()
+    tool_command.install("sample-tool")
+    output = capsys.readouterr().out
+    assert "Successfully installed sample-tool" in output

     mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()])
     mock_subprocess_run.assert_any_call(

@@ -86,54 +85,42 @@ def test_install_success(mock_get, mock_subprocess_run):
         env=unittest.mock.ANY,
     )

-    assert "Successfully installed sample-tool" in output
-

 @patch("crewai.cli.plus_api.PlusAPI.get_tool")
-def test_install_tool_not_found(mock_get):
+def test_install_tool_not_found(mock_get, capsys, tool_command):
     mock_get_response = MagicMock()
     mock_get_response.status_code = 404
     mock_get.return_value = mock_get_response

-    tool_command = ToolCommand()
-
-    with patch("sys.stdout", new=StringIO()) as fake_out:
-        try:
-            tool_command.install("non-existent-tool")
-        except SystemExit:
-            pass
-        output = fake_out.getvalue()
+    with raises(SystemExit):
+        tool_command.install("non-existent-tool")
+    output = capsys.readouterr().out
+    assert "No tool found with this name" in output

     mock_get.assert_called_once_with("non-existent-tool")
-    assert "No tool found with this name" in output


 @patch("crewai.cli.plus_api.PlusAPI.get_tool")
-def test_install_api_error(mock_get):
+def test_install_api_error(mock_get, capsys, tool_command):
     mock_get_response = MagicMock()
     mock_get_response.status_code = 500
     mock_get.return_value = mock_get_response

-    tool_command = ToolCommand()
-
-    with patch("sys.stdout", new=StringIO()) as fake_out:
-        try:
-            tool_command.install("error-tool")
-        except SystemExit:
-            pass
-        output = fake_out.getvalue()
+    with raises(SystemExit):
+        tool_command.install("error-tool")
+    output = capsys.readouterr().out
+    assert "Failed to get tool details" in output

     mock_get.assert_called_once_with("error-tool")
-    assert "Failed to get tool details" in output


 @patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
-def test_publish_when_not_in_sync(mock_is_synced):
-    with patch("sys.stdout", new=StringIO()) as fake_out, raises(SystemExit):
-        tool_command = ToolCommand()
+def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command):
+    with raises(SystemExit):
         tool_command.publish(is_public=True)

-    assert "Local changes need to be resolved before publishing" in fake_out.getvalue()
+    output = capsys.readouterr().out
+    assert "Local changes need to be resolved before publishing" in output


 @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")

@@ -157,13 +144,13 @@ def test_publish_when_not_in_sync_and_force(
     mock_get_project_description,
     mock_get_project_version,
     mock_get_project_name,
+    tool_command,
 ):
     mock_publish_response = MagicMock()
     mock_publish_response.status_code = 200
     mock_publish_response.json.return_value = {"handle": "sample-tool"}
     mock_publish.return_value = mock_publish_response

-    tool_command = ToolCommand()
     tool_command.publish(is_public=True, force=True)

     mock_get_project_name.assert_called_with(require=True)

@@ -205,13 +192,13 @@ def test_publish_success(
     mock_get_project_description,
     mock_get_project_version,
     mock_get_project_name,
+    tool_command,
 ):
     mock_publish_response = MagicMock()
     mock_publish_response.status_code = 200
     mock_publish_response.json.return_value = {"handle": "sample-tool"}
     mock_publish.return_value = mock_publish_response

-    tool_command = ToolCommand()
     tool_command.publish(is_public=True)

     mock_get_project_name.assert_called_with(require=True)

@@ -251,25 +238,22 @@ def test_publish_failure(
     mock_get_project_description,
     mock_get_project_version,
     mock_get_project_name,
+    capsys,
+    tool_command,
 ):
     mock_publish_response = MagicMock()
     mock_publish_response.status_code = 422
     mock_publish_response.json.return_value = {"name": ["is already taken"]}
     mock_publish.return_value = mock_publish_response

-    tool_command = ToolCommand()
-
-    with patch("sys.stdout", new=StringIO()) as fake_out:
-        try:
-            tool_command.publish(is_public=True)
-        except SystemExit:
-            pass
-        output = fake_out.getvalue()
-
-    mock_publish.assert_called_once()
+    with raises(SystemExit):
+        tool_command.publish(is_public=True)
+    output = capsys.readouterr().out
     assert "Failed to complete operation" in output
     assert "Name is already taken" in output

+    mock_publish.assert_called_once()
+

 @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
 @patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")

@@ -290,6 +274,8 @@ def test_publish_api_error(
     mock_get_project_description,
     mock_get_project_version,
     mock_get_project_name,
+    capsys,
+    tool_command,
 ):
     mock_response = MagicMock()
     mock_response.status_code = 500

@@ -297,14 +283,9 @@ def test_publish_api_error(
     mock_response.ok = False
     mock_publish.return_value = mock_response

-    tool_command = ToolCommand()
-
-    with patch("sys.stdout", new=StringIO()) as fake_out:
-        try:
-            tool_command.publish(is_public=True)
-        except SystemExit:
-            pass
-        output = fake_out.getvalue()
+    with raises(SystemExit):
+        tool_command.publish(is_public=True)
+    output = capsys.readouterr().out
+    assert "Request to Enterprise API failed" in output

     mock_publish.assert_called_once()
-    assert "Request to Enterprise API failed" in output
@@ -17,6 +17,7 @@ from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.flow import Flow, listen, start
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
@@ -42,29 +43,38 @@ from crewai.utilities.events.event_listener import EventListener
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
|
||||
|
||||
ceo = Agent(
|
||||
role="CEO",
|
||||
goal="Make sure the writers in your company produce amazing content.",
|
||||
backstory="You're an long time CEO of a content creation agency with a Senior Writer on the team. You're now working on a new project and want to make sure the content produced is amazing.",
|
||||
allow_delegation=True,
|
||||
)
|
||||
|
||||
researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="Make the best research and analysis on content about AI and AI agents",
|
||||
backstory="You're an expert researcher, specialized in technology, software engineering, AI and startups. You work as a freelancer and is now working on doing research and analysis for a new customer.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
writer = Agent(
|
||||
role="Senior Writer",
|
||||
goal="Write the best content about AI and AI agents.",
|
||||
backstory="You're a senior writer, specialized in technology, software engineering, AI and startups. You work as a freelancer and are now working on writing content for a new customer.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
@pytest.fixture
|
||||
def ceo():
|
||||
return Agent(
|
||||
role="CEO",
|
||||
goal="Make sure the writers in your company produce amazing content.",
|
||||
backstory="You're an long time CEO of a content creation agency with a Senior Writer on the team. You're now working on a new project and want to make sure the content produced is amazing.",
|
||||
allow_delegation=True,
|
||||
)
|
||||
|
||||
|
||||
def test_crew_with_only_conditional_tasks_raises_error():
|
||||
@pytest.fixture
|
||||
def researcher():
|
||||
return Agent(
|
||||
role="Researcher",
|
||||
goal="Make the best research and analysis on content about AI and AI agents",
|
||||
backstory="You're an expert researcher, specialized in technology, software engineering, AI and startups. You work as a freelancer and is now working on doing research and analysis for a new customer.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def writer():
|
||||
return Agent(
|
||||
role="Senior Writer",
|
||||
goal="Write the best content about AI and AI agents.",
|
||||
backstory="You're a senior writer, specialized in technology, software engineering, AI and startups. You work as a freelancer and are now working on writing content for a new customer.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
|
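The module-level ceo, researcher, and writer agents above become function-scoped pytest fixtures, so each test receives a fresh Agent by parameter name and no state leaks between tests once pytest-randomly shuffles execution order. A sketch of how the same fixtures could instead live in a shared conftest.py (hypothetical placement; this diff keeps them inside the test module):

# tests/conftest.py (hypothetical location)
import pytest

from crewai import Agent


@pytest.fixture
def researcher():
    # Function-scoped by default: every test gets its own Agent instance.
    return Agent(
        role="Researcher",
        goal="Make the best research and analysis on content about AI and AI agents",
        backstory="You're an expert researcher, specialized in technology, software engineering, AI and startups. You work as a freelancer and is now working on doing research and analysis for a new customer.",
        allow_delegation=False,
    )
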
def test_crew_with_only_conditional_tasks_raises_error(researcher):
    """Test that creating a crew with only conditional tasks raises an error."""

    def condition_func(task_output: TaskOutput) -> bool:
@@ -146,7 +156,9 @@ def test_crew_config_conditional_requirement():
    ]


def test_async_task_cannot_include_sequential_async_tasks_in_context():
def test_async_task_cannot_include_sequential_async_tasks_in_context(
    researcher, writer
):
    task1 = Task(
        description="Task 1",
        async_execution=True,
@@ -194,7 +206,7 @@ def test_async_task_cannot_include_sequential_async_tasks_in_context():
        pytest.fail("Unexpected ValidationError raised")


def test_context_no_future_tasks():
def test_context_no_future_tasks(researcher, writer):
    task2 = Task(
        description="Task 2",
        expected_output="output",
@@ -258,7 +270,7 @@ def test_crew_config_with_wrong_keys():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_creation():
def test_crew_creation(researcher, writer):
    tasks = [
        Task(
            description="Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting.",
@@ -290,7 +302,7 @@ def test_crew_creation():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_sync_task_execution():
def test_sync_task_execution(researcher, writer):
    from unittest.mock import patch

    tasks = [
@@ -331,7 +343,7 @@ def test_sync_task_execution():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_hierarchical_process():
def test_hierarchical_process(researcher, writer):
    task = Task(
        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
        expected_output="5 bullet points with a paragraph for each idea.",
@@ -352,7 +364,7 @@ def test_hierarchical_process():
    )


def test_manager_llm_requirement_for_hierarchical_process():
def test_manager_llm_requirement_for_hierarchical_process(researcher, writer):
    task = Task(
        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
        expected_output="5 bullet points with a paragraph for each idea.",
@@ -367,7 +379,7 @@ def test_manager_llm_requirement_for_hierarchical_process():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_manager_agent_delegating_to_assigned_task_agent():
def test_manager_agent_delegating_to_assigned_task_agent(researcher, writer):
    """
    Test that the manager agent delegates to the assigned task agent.
    """
@@ -419,7 +431,7 @@ def test_manager_agent_delegating_to_assigned_task_agent():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_manager_agent_delegating_to_all_agents():
def test_manager_agent_delegating_to_all_agents(researcher, writer):
    """
    Test that the manager agent delegates to all agents when none are specified.
    """
@@ -529,7 +541,7 @@ def test_manager_agent_delegates_with_varied_role_cases():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_with_delegating_agents():
def test_crew_with_delegating_agents(ceo, writer):
    tasks = [
        Task(
            description="Produce and amazing 1 paragraph draft of an article about AI Agents.",
@@ -553,7 +565,7 @@ def test_crew_with_delegating_agents():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_with_delegating_agents_should_not_override_task_tools():
def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer):
    from typing import Type

    from pydantic import BaseModel, Field
@@ -615,7 +627,7 @@ def test_crew_with_delegating_agents_should_not_override_task_tools():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_with_delegating_agents_should_not_override_agent_tools():
def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer):
    from typing import Type

    from pydantic import BaseModel, Field
@@ -679,7 +691,7 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_task_tools_override_agent_tools():
def test_task_tools_override_agent_tools(researcher):
    from typing import Type

    from pydantic import BaseModel, Field
@@ -734,7 +746,7 @@ def test_task_tools_override_agent_tools():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_task_tools_override_agent_tools_with_allow_delegation():
def test_task_tools_override_agent_tools_with_allow_delegation(researcher, writer):
    """
    Test that task tools override agent tools while preserving delegation tools when allow_delegation=True
    """
@@ -817,7 +829,7 @@ def test_task_tools_override_agent_tools_with_allow_delegation():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_verbose_output(capsys):
def test_crew_verbose_output(researcher, writer, capsys):
    tasks = [
        Task(
            description="Research AI advancements.",
@@ -877,7 +889,7 @@ def test_crew_verbose_output(capsys):


@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_hitting_between_agents():
def test_cache_hitting_between_agents(researcher, writer, ceo):
    from unittest.mock import call, patch

    from crewai.tools import tool
@@ -1050,7 +1062,7 @@ def test_agents_rpm_is_never_set_if_crew_max_RPM_is_not_set():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_sequential_async_task_execution_completion():
def test_sequential_async_task_execution_completion(researcher, writer):
    list_ideas = Task(
        description="Give me a list of 5 interesting ideas to explore for an article, what makes them unique and interesting.",
        expected_output="Bullet point list of 5 important events.",
@@ -1204,7 +1216,7 @@ async def test_crew_async_kickoff():

@pytest.mark.asyncio
@pytest.mark.vcr(filter_headers=["authorization"])
async def test_async_task_execution_call_count():
async def test_async_task_execution_call_count(researcher, writer):
    from unittest.mock import MagicMock, patch

    list_ideas = Task(
@@ -1707,7 +1719,7 @@ def test_agents_do_not_get_delegation_tools_with_there_is_only_one_agent():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_sequential_crew_creation_tasks_without_agents():
def test_sequential_crew_creation_tasks_without_agents(researcher):
    task = Task(
        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
        expected_output="5 bullet points with a paragraph for each idea.",
@@ -1757,7 +1769,7 @@ def test_agent_usage_metrics_are_captured_for_hierarchical_process():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_hierarchical_crew_creation_tasks_with_agents():
def test_hierarchical_crew_creation_tasks_with_agents(researcher, writer):
    """
    Agents are not required for tasks in a hierarchical process but sometimes they are still added
    This test makes sure that the manager still delegates the task to the agent even if the agent is passed in the task
@@ -1810,7 +1822,7 @@ def test_hierarchical_crew_creation_tasks_with_agents():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_hierarchical_crew_creation_tasks_with_async_execution():
def test_hierarchical_crew_creation_tasks_with_async_execution(researcher, writer, ceo):
    """
    Tests that async tasks in hierarchical crews are handled correctly with proper delegation tools
    """
@@ -1867,7 +1879,7 @@ def test_hierarchical_crew_creation_tasks_with_async_execution():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_hierarchical_crew_creation_tasks_with_sync_last():
def test_hierarchical_crew_creation_tasks_with_sync_last(researcher, writer, ceo):
    """
    Agents are not required for tasks in a hierarchical process but sometimes they are still added
    This test makes sure that the manager still delegates the task to the agent even if the agent is passed in the task
@@ -2153,7 +2165,6 @@ def test_tools_with_custom_caching():
    with patch.object(
        CacheHandler, "add", wraps=crew._cache_handler.add
    ) as add_to_cache:

        result = crew.kickoff()

        # Check that add_to_cache was called exactly twice
@@ -2170,7 +2181,7 @@ def test_tools_with_custom_caching():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_conditional_task_uses_last_output():
def test_conditional_task_uses_last_output(researcher, writer):
    """Test that conditional tasks use the last task output for condition evaluation."""
    task1 = Task(
        description="First task",
@@ -2244,7 +2255,7 @@ def test_conditional_task_uses_last_output():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_conditional_tasks_result_collection():
def test_conditional_tasks_result_collection(researcher, writer):
    """Test that task outputs are properly collected based on execution status."""
    task1 = Task(
        description="Normal task that always executes",
@@ -2325,7 +2336,7 @@ def test_conditional_tasks_result_collection():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_multiple_conditional_tasks():
def test_multiple_conditional_tasks(researcher, writer):
    """Test that having multiple conditional tasks in sequence works correctly."""
    task1 = Task(
        description="Initial research task",
@@ -2560,7 +2571,7 @@ def test_disabled_memory_using_contextual_memory():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_log_file_output(tmp_path):
def test_crew_log_file_output(tmp_path, researcher):
    test_file = tmp_path / "logs.txt"
    tasks = [
        Task(
@@ -2658,7 +2669,7 @@ def test_crew_output_file_validation_failures():
        Crew(agents=[agent], tasks=[task]).kickoff()


def test_manager_agent():
def test_manager_agent(researcher, writer):
    from unittest.mock import patch

    task = Task(
@@ -2696,7 +2707,7 @@ def test_manager_agent():
        mock_execute_sync.assert_called()


def test_manager_agent_in_agents_raises_exception():
def test_manager_agent_in_agents_raises_exception(researcher, writer):
    task = Task(
        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
        expected_output="5 bullet points with a paragraph for each idea.",
@@ -2718,7 +2729,7 @@ def test_manager_agent_in_agents_raises_exception():
    )


def test_manager_agent_with_tools_raises_exception():
def test_manager_agent_with_tools_raises_exception(researcher, writer):
    from crewai.tools import tool

    @tool
@@ -2755,7 +2766,7 @@ def test_manager_agent_with_tools_raises_exception():
@patch("crewai.crew.TaskEvaluator")
@patch("crewai.crew.Crew.copy")
def test_crew_train_success(
    copy_mock, task_evaluator, crew_training_handler, kickoff_mock
    copy_mock, task_evaluator, crew_training_handler, kickoff_mock, researcher, writer
):
    task = Task(
        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
@@ -2831,7 +2842,7 @@ def test_crew_train_success(
    assert isinstance(received_events[1], CrewTrainCompletedEvent)


def test_crew_train_error():
def test_crew_train_error(researcher, writer):
    task = Task(
        description="Come up with a list of 5 interesting ideas to explore for an article",
        expected_output="5 bullet points with a paragraph for each idea.",
@@ -2850,7 +2861,7 @@ def test_crew_train_error():
    )


def test__setup_for_training():
def test__setup_for_training(researcher, writer):
    researcher.allow_delegation = True
    writer.allow_delegation = True
    agents = [researcher, writer]
@@ -2881,7 +2892,7 @@ def test__setup_for_training():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_replay_feature():
def test_replay_feature(researcher, writer):
    list_ideas = Task(
        description="Generate a list of 5 interesting ideas to explore for an article, where each bulletpoint is under 15 words.",
        expected_output="Bullet point list of 5 important events. No additional commentary.",
@@ -2918,7 +2929,7 @@ def test_replay_feature():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_replay_error():
def test_crew_replay_error(researcher, writer):
    task = Task(
        description="Come up with a list of 5 interesting ideas to explore for an article",
        expected_output="5 bullet points with a paragraph for each idea.",
@@ -3314,7 +3325,7 @@ def test_replay_setup_context():
    assert crew.tasks[1].prompt_context == "context raw output"


def test_key():
def test_key(researcher, writer):
    tasks = [
        Task(
            description="Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting.",
@@ -3383,7 +3394,9 @@ def test_key_with_interpolated_inputs():
    assert crew.key == curr_key


def test_conditional_task_requirement_breaks_when_singular_conditional_task():
def test_conditional_task_requirement_breaks_when_singular_conditional_task(
    researcher, writer
):
    def condition_fn(output) -> bool:
        return output.raw.startswith("Andrew Ng has!!")

@@ -3401,7 +3414,7 @@ def test_conditional_task_requirement_breaks_when_singular_conditional_task():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_conditional_task_last_task_when_conditional_is_true():
def test_conditional_task_last_task_when_conditional_is_true(researcher, writer):
    def condition_fn(output) -> bool:
        return True

@@ -3428,7 +3441,7 @@ def test_conditional_task_last_task_when_conditional_is_true():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_conditional_task_last_task_when_conditional_is_false():
def test_conditional_task_last_task_when_conditional_is_false(researcher, writer):
    def condition_fn(output) -> bool:
        return False

@@ -3452,7 +3465,7 @@ def test_conditional_task_last_task_when_conditional_is_false():
    assert result.raw == "Hi"


def test_conditional_task_requirement_breaks_when_task_async():
def test_conditional_task_requirement_breaks_when_task_async(researcher, writer):
    def my_condition(context):
        return context.get("some_value") > 10

@@ -3477,7 +3490,7 @@ def test_conditional_task_requirement_breaks_when_task_async():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_conditional_should_skip():
def test_conditional_should_skip(researcher, writer):
    task1 = Task(description="Return hello", expected_output="say hi", agent=researcher)

    condition_mock = MagicMock(return_value=False)
@@ -3509,7 +3522,7 @@ def test_conditional_should_skip():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_conditional_should_execute():
def test_conditional_should_execute(researcher, writer):
    task1 = Task(description="Return hello", expected_output="say hi", agent=researcher)

    condition_mock = MagicMock(
@@ -3542,7 +3555,7 @@ def test_conditional_should_execute():
@mock.patch("crewai.crew.CrewEvaluator")
@mock.patch("crewai.crew.Crew.copy")
@mock.patch("crewai.crew.Crew.kickoff")
def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator, researcher):
    task = Task(
        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
        expected_output="5 bullet points with a paragraph for each idea.",
@@ -3592,7 +3605,7 @@ def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):


@pytest.mark.vcr(filter_headers=["authorization"])
def test_hierarchical_verbose_manager_agent():
def test_hierarchical_verbose_manager_agent(researcher, writer):
    task = Task(
        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
        expected_output="5 bullet points with a paragraph for each idea.",
@@ -3613,7 +3626,7 @@ def test_hierarchical_verbose_manager_agent():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_hierarchical_verbose_false_manager_agent():
def test_hierarchical_verbose_false_manager_agent(researcher, writer):
    task = Task(
        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
        expected_output="5 bullet points with a paragraph for each idea.",
@@ -4186,7 +4199,7 @@ def test_before_kickoff_without_inputs():


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_with_knowledge_sources_works_with_copy():
def test_crew_with_knowledge_sources_works_with_copy(researcher, writer):
    content = "Brandon's favorite color is red and he likes Mexican food."
    string_source = StringKnowledgeSource(content=content)

@@ -4195,7 +4208,6 @@ def test_crew_with_knowledge_sources_works_with_copy():
        tasks=[Task(description="test", expected_output="test", agent=researcher)],
        knowledge_sources=[string_source],
    )

    crew_copy = crew.copy()

    assert crew_copy.knowledge_sources == crew.knowledge_sources
@@ -4339,3 +4351,35 @@ def test_crew_copy_with_memory():
        raise e  # Re-raise other validation errors
    except Exception as e:
        pytest.fail(f"Copying crew raised an unexpected exception: {e}")


def test_sets_parent_flow_when_outside_flow(researcher, writer):
    crew = Crew(
        agents=[researcher, writer],
        process=Process.sequential,
        tasks=[
            Task(description="Task 1", expected_output="output", agent=researcher),
            Task(description="Task 2", expected_output="output", agent=writer),
        ],
    )
    assert crew.parent_flow is None


def test_sets_parent_flow_when_inside_flow(researcher, writer):
    class MyFlow(Flow):
        @start()
        def start(self):
            return Crew(
                agents=[researcher, writer],
                process=Process.sequential,
                tasks=[
                    Task(
                        description="Task 1", expected_output="output", agent=researcher
                    ),
                    Task(description="Task 2", expected_output="output", agent=writer),
                ],
            )

    flow = MyFlow()
    result = flow.kickoff()
    assert result.parent_flow is flow

@@ -547,6 +547,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
    mock_vector_db.query.assert_called_once()


@pytest.mark.vcr
def test_docling_source(mock_vector_db):
    docling_source = CrewDoclingSource(
        file_paths=[
@@ -567,6 +568,7 @@ def test_docling_source(mock_vector_db):
    mock_vector_db.query.assert_called_once()


@pytest.mark.vcr
def test_multiple_docling_sources():
    urls: List[Union[Path, str]] = [
        "https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",

155
tests/llm_response_cache_test.py
Normal file
@@ -0,0 +1,155 @@
import sqlite3
import threading
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

from crewai.llm import LLM
from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler


@pytest.fixture
def handler():
    handler = LLMResponseCacheHandler()
    handler.storage.add = MagicMock()
    handler.storage.get = MagicMock()
    return handler


def create_mock_response(content):
    """Create a properly structured mock response object for litellm.completion"""
    message = SimpleNamespace(content=content)
    choice = SimpleNamespace(message=message)
    response = SimpleNamespace(choices=[choice])
    return response


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_recording(handler):
    handler.start_recording()

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        mock_completion.return_value = create_mock_response("Hello, human!")

        response = llm.call(messages)

        assert response == "Hello, human!"

        handler.storage.add.assert_called_once_with(
            "gpt-4o-mini", messages, "Hello, human!"
        )


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replaying(handler):
    handler.start_replaying()
    handler.storage.get.return_value = "Cached response"

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        response = llm.call(messages)

        assert response == "Cached response"

        mock_completion.assert_not_called()

    handler.storage.get.assert_called_once_with("gpt-4o-mini", messages)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replay_fallback(handler):
    handler.start_replaying()
    handler.storage.get.return_value = None

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        mock_completion.return_value = create_mock_response("Hello, human!")

        response = llm.call(messages)

        assert response == "Hello, human!"

        mock_completion.assert_called_once()


@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_error_handling():
    """Test that errors during cache operations are handled gracefully."""
    handler = LLMResponseCacheHandler()

    handler.storage.add = MagicMock(side_effect=sqlite3.Error("Mock DB error"))
    handler.storage.get = MagicMock(side_effect=sqlite3.Error("Mock DB error"))

    handler.start_recording()

    handler.cache_response("model", [{"role": "user", "content": "test"}], "response")

    handler.start_replaying()

    assert handler.get_cached_response("model", [{"role": "user", "content": "test"}]) is None


@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_expiration():
    """Test that cache expiration works correctly."""
    import sqlite3

    conn = sqlite3.connect(":memory:")
    cursor = conn.cursor()
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS llm_response_cache (
            request_hash TEXT PRIMARY KEY,
            model TEXT,
            messages TEXT,
            response TEXT,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        """
    )
    conn.commit()

    storage = LLMResponseCacheStorage(":memory:")

    original_get_connection = storage._get_connection
    storage._get_connection = lambda: conn

    try:
        model = "test-model"
        messages = [{"role": "user", "content": "test"}]
        response = "test response"
        storage.add(model, messages, response)

        assert storage.get(model, messages) == response

        storage.cleanup_expired_cache(max_age_days=0)

        assert storage.get(model, messages) is None
    finally:
        storage._get_connection = original_get_connection
        conn.close()


@pytest.mark.vcr(filter_headers=["authorization"])
def test_concurrent_cache_access():
    """Test that concurrent cache access works correctly."""
    pytest.skip("SQLite in-memory databases are not shared between threads")


    # storage = LLMResponseCacheStorage(temp_db.name)
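These tests pin down the storage surface: add(model, messages, response), get(model, messages), cleanup_expired_cache(max_age_days=...), and a _get_connection() hook over the llm_response_cache table whose schema appears in test_cache_expiration. A minimal sketch of a storage class consistent with that surface, offered as an assumption rather than the PR's actual LLMResponseCacheStorage:

import hashlib
import json
import sqlite3


class ResponseCacheSketch:
    """Hypothetical storage matching the behavior the tests exercise."""

    def __init__(self, db_path: str):
        self.db_path = db_path

    def _get_connection(self) -> sqlite3.Connection:
        # The expiration test monkeypatches this to share one connection.
        return sqlite3.connect(self.db_path)

    @staticmethod
    def _request_hash(model: str, messages: list) -> str:
        # Key the row on a stable digest of model + messages.
        payload = json.dumps({"model": model, "messages": messages}, sort_keys=True)
        return hashlib.sha256(payload.encode()).hexdigest()

    def add(self, model: str, messages: list, response: str) -> None:
        conn = self._get_connection()
        conn.execute(
            "INSERT OR REPLACE INTO llm_response_cache"
            " (request_hash, model, messages, response) VALUES (?, ?, ?, ?)",
            (self._request_hash(model, messages), model, json.dumps(messages), response),
        )
        conn.commit()

    def get(self, model: str, messages: list):
        conn = self._get_connection()
        row = conn.execute(
            "SELECT response FROM llm_response_cache WHERE request_hash = ?",
            (self._request_hash(model, messages),),
        ).fetchone()
        return row[0] if row else None

    def cleanup_expired_cache(self, max_age_days: int) -> None:
        conn = self._get_connection()
        # With max_age_days=0 this deletes rows stamped at or before "now",
        # which is exactly what test_cache_expiration asserts.
        conn.execute(
            "DELETE FROM llm_response_cache WHERE timestamp <= datetime('now', ?)",
            (f"-{max_age_days} days",),
        )
        conn.commit()

The sketch assumes the table already exists (the test creates it) and deliberately skips connection pooling and error handling.
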
93
tests/record_replay_test.py
Normal file
@@ -0,0 +1,93 @@
from unittest.mock import MagicMock, patch

import pytest

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.process import Process
from crewai.task import Task


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_recording_mode():
    agent = Agent(
        role="Test Agent",
        goal="Test the recording functionality",
        backstory="A test agent for recording LLM responses",
    )

    task = Task(
        description="Return a simple response",
        expected_output="A simple response",
        agent=agent,
    )

    crew = Crew(
        agents=[agent],
        tasks=[task],
        process=Process.sequential,
        record_mode=True,
    )

    mock_handler = MagicMock()
    crew._llm_response_cache_handler = mock_handler

    mock_llm = MagicMock()
    agent.llm = mock_llm

    with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
        with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
            crew.kickoff()

    mock_handler.start_recording.assert_called_once()

    mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_replay_mode():
    agent = Agent(
        role="Test Agent",
        goal="Test the replay functionality",
        backstory="A test agent for replaying LLM responses",
    )

    task = Task(
        description="Return a simple response",
        expected_output="A simple response",
        agent=agent,
    )

    crew = Crew(
        agents=[agent],
        tasks=[task],
        process=Process.sequential,
        replay_mode=True,
    )

    mock_handler = MagicMock()
    crew._llm_response_cache_handler = mock_handler

    mock_llm = MagicMock()
    agent.llm = mock_llm

    with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
        with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
            crew.kickoff()

    mock_handler.start_replaying.assert_called_once()

    mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_record_replay_flags_conflict():
    with pytest.raises(ValueError):
        crew = Crew(
            agents=[],
            tasks=[],
            process=Process.sequential,
            record_mode=True,
            replay_mode=True,
        )
        crew.kickoff()
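Only record_mode, replay_mode, and their mutual exclusivity are established by these tests; everything else in the following usage sketch is an assumption:

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.process import Process
from crewai.task import Task

agent = Agent(role="Writer", goal="Say hi", backstory="A hypothetical test agent")
task = Task(description="Say hi", expected_output="hi", agent=agent)

# First run: real LLM calls, with responses recorded for later replay.
Crew(agents=[agent], tasks=[task], process=Process.sequential, record_mode=True).kickoff()

# Later runs: answers come from the recorded responses, no network needed.
Crew(agents=[agent], tasks=[task], process=Process.sequential, replay_mode=True).kickoff()

# Passing record_mode=True and replay_mode=True together raises a ValueError.
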
69
tests/telemetry/test_telemetry.py
Normal file
@@ -0,0 +1,69 @@
import os
from unittest.mock import patch

import pytest

from crewai import Agent, Crew, Task
from crewai.telemetry import Telemetry


@pytest.mark.parametrize(
    "env_var,value,expected_ready",
    [
        ("OTEL_SDK_DISABLED", "true", False),
        ("OTEL_SDK_DISABLED", "TRUE", False),
        ("CREWAI_DISABLE_TELEMETRY", "true", False),
        ("CREWAI_DISABLE_TELEMETRY", "TRUE", False),
        ("OTEL_SDK_DISABLED", "false", True),
        ("CREWAI_DISABLE_TELEMETRY", "false", True),
    ],
)
def test_telemetry_environment_variables(env_var, value, expected_ready):
    """Test telemetry state with different environment variable configurations."""
    with patch.dict(os.environ, {env_var: value}):
        with patch("crewai.telemetry.telemetry.TracerProvider"):
            telemetry = Telemetry()
            assert telemetry.ready is expected_ready


def test_telemetry_enabled_by_default():
    """Test that telemetry is enabled by default."""
    with patch.dict(os.environ, {}, clear=True):
        with patch("crewai.telemetry.telemetry.TracerProvider"):
            telemetry = Telemetry()
            assert telemetry.ready is True


from opentelemetry import trace


@patch("crewai.telemetry.telemetry.logger.error")
@patch(
    "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter.export",
    side_effect=Exception("Test exception"),
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_telemetry_fails_due_connect_timeout(export_mock, logger_mock):
    error = Exception("Test exception")
    export_mock.side_effect = error

    tracer = trace.get_tracer(__name__)
    with tracer.start_as_current_span("test-span"):
        agent = Agent(
            role="agent",
            llm="gpt-4o-mini",
            goal="Just say hi",
            backstory="You are a helpful assistant that just says hi",
        )
        task = Task(
            description="Just say hi",
            expected_output="hi",
            agent=agent,
        )
        crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
        crew.kickoff()

    trace.get_tracer_provider().force_flush()

    export_mock.assert_called_once()
    logger_mock.assert_called_once_with(error)
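The parametrized cases encode a simple gating rule: a case-insensitive "true" in either OTEL_SDK_DISABLED or CREWAI_DISABLE_TELEMETRY must leave telemetry not ready, and anything else keeps the default of enabled. A minimal check satisfying that table (an assumption about the implementation, not a copy of it):

import os


def telemetry_disabled() -> bool:
    # "true"/"TRUE" disable telemetry; "false" or unset leaves it enabled,
    # mirroring the expected_ready column in the parametrization above.
    return any(
        os.environ.get(var, "").lower() == "true"
        for var in ("OTEL_SDK_DISABLED", "CREWAI_DISABLE_TELEMETRY")
    )
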
@@ -1,13 +1,16 @@
|
||||
import asyncio
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai import LLM, Agent
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.agent_events import LiteAgentExecutionStartedEvent
|
||||
from crewai.utilities.events.tool_usage_events import ToolUsageStartedEvent
|
||||
|
||||
|
||||
@@ -255,3 +258,60 @@ async def test_lite_agent_returns_usage_metrics_async():
|
||||
assert "21 million" in result.raw or "37 million" in result.raw
|
||||
assert result.usage_metrics is not None
|
||||
assert result.usage_metrics["total_tokens"] > 0
|
||||
|
||||
|
||||
class TestFlow(Flow):
|
||||
"""A test flow that creates and runs an agent."""
|
||||
|
||||
def __init__(self, llm, tools):
|
||||
self.llm = llm
|
||||
self.tools = tools
|
||||
super().__init__()
|
||||
|
||||
@start()
|
||||
def start(self):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test Goal",
|
||||
backstory="Test Backstory",
|
||||
llm=self.llm,
|
||||
tools=self.tools,
|
||||
)
|
||||
return agent.kickoff("Test query")
|
||||
|
||||
|
||||
def verify_agent_parent_flow(result, agent, flow):
|
||||
"""Verify that both the result and agent have the correct parent flow."""
|
||||
assert result.parent_flow is flow
|
||||
assert agent is not None
|
||||
assert agent.parent_flow is flow
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_inside_flow():
|
||||
captured_agent = None
|
||||
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.call.return_value = "Test response"
|
||||
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def start(self):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test Goal",
|
||||
backstory="Test Backstory",
|
||||
llm=mock_llm,
|
||||
tools=[WebSearchTool()],
|
||||
)
|
||||
return agent.kickoff("Test query")
|
||||
|
||||
flow = MyFlow()
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def capture_agent(source, event):
|
||||
nonlocal captured_agent
|
||||
captured_agent = source
|
||||
|
||||
result = flow.kickoff()
|
||||
assert captured_agent.parent_flow is flow
|
||||
|
||||
@@ -32,3 +32,16 @@ def test_wildcard_event_handler():
|
||||
crewai_event_bus.emit("source_object", event)
|
||||
|
||||
mock_handler.assert_called_once_with("source_object", event)
|
||||
|
||||
|
||||
def test_event_bus_error_handling(capfd):
|
||||
@crewai_event_bus.on(BaseEvent)
|
||||
def broken_handler(source, event):
|
||||
raise ValueError("Simulated handler failure")
|
||||
|
||||
event = TestEvent(type="test_event")
|
||||
crewai_event_bus.emit("source_object", event)
|
||||
|
||||
out, err = capfd.readouterr()
|
||||
assert "Simulated handler failure" in out
|
||||
assert "Handler 'broken_handler' failed" in out
|
||||
|
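test_event_bus_error_handling asserts two things at once: a failing handler must not break emit(), and the failure must be reported on stdout with both the handler's name and the original error. A sketch of per-handler isolation consistent with those assertions (hypothetical; the real crewai_event_bus has more responsibilities):

from collections import defaultdict


class IsolatingEventBus:
    """Hypothetical bus whose emit() survives misbehaving handlers."""

    def __init__(self):
        self._handlers = defaultdict(list)

    def on(self, event_type):
        def register(handler):
            self._handlers[event_type].append(handler)
            return handler

        return register

    def emit(self, source, event):
        for event_type, handlers in self._handlers.items():
            if not isinstance(event, event_type):
                continue
            for handler in handlers:
                try:
                    handler(source, event)
                except Exception as exc:
                    # Report instead of re-raising, so one broken handler
                    # cannot starve the rest; the test greps stdout for
                    # both the handler name and the error message.
                    print(f"Handler '{handler.__name__}' failed: {exc}")
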
36
uv.lock
generated
@@ -811,8 +811,10 @@ dev = [
    { name = "pre-commit" },
    { name = "pytest" },
    { name = "pytest-asyncio" },
    { name = "pytest-randomly" },
    { name = "pytest-recording" },
    { name = "pytest-subprocess" },
    { name = "pytest-timeout" },
    { name = "python-dotenv" },
    { name = "ruff" },
]
@@ -833,7 +835,7 @@ requires-dist = [
    { name = "json-repair", specifier = ">=0.25.2" },
    { name = "json5", specifier = ">=0.10.0" },
    { name = "jsonref", specifier = ">=1.1.0" },
    { name = "litellm", specifier = "==1.67.1" },
    { name = "litellm", specifier = "==1.68.0" },
    { name = "mem0ai", marker = "extra == 'mem0'", specifier = ">=0.1.94" },
    { name = "openai", specifier = ">=1.13.3" },
    { name = "openpyxl", specifier = ">=3.1.5" },
@@ -867,8 +869,10 @@ dev = [
    { name = "pre-commit", specifier = ">=3.6.0" },
    { name = "pytest", specifier = ">=8.0.0" },
    { name = "pytest-asyncio", specifier = ">=0.23.7" },
    { name = "pytest-randomly", specifier = ">=3.16.0" },
    { name = "pytest-recording", specifier = ">=0.13.2" },
    { name = "pytest-subprocess", specifier = ">=1.5.2" },
    { name = "pytest-timeout", specifier = ">=2.3.1" },
    { name = "python-dotenv", specifier = ">=1.0.0" },
    { name = "ruff", specifier = ">=0.8.2" },
]
@@ -2383,7 +2387,7 @@ wheels = [

[[package]]
name = "litellm"
version = "1.67.1"
version = "1.68.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "aiohttp" },
@@ -2398,9 +2402,9 @@ dependencies = [
    { name = "tiktoken" },
    { name = "tokenizers" },
]
sdist = { url = "https://files.pythonhosted.org/packages/54/a4/bb3e9ae59e5a9857443448de7c04752630dc84cddcbd8cee037c0976f44f/litellm-1.67.1.tar.gz", hash = "sha256:78eab1bd3d759ec13aa4a05864356a4a4725634e78501db609d451bf72150ee7", size = 7242044 }
sdist = { url = "https://files.pythonhosted.org/packages/ba/22/138545b646303ca3f4841b69613c697b9d696322a1386083bb70bcbba60b/litellm-1.68.0.tar.gz", hash = "sha256:9fb24643db84dfda339b64bafca505a2eef857477afbc6e98fb56512c24dbbfa", size = 7314051 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/88/86/c14d3c24ae13c08296d068e6f79fd4bd17a0a07bddbda94990b87c35d20e/litellm-1.67.1-py3-none-any.whl", hash = "sha256:8fff5b2a16b63bb594b94d6c071ad0f27d3d8cd4348bd5acea2fd40c8e0c11e8", size = 7607266 },
    { url = "https://files.pythonhosted.org/packages/10/af/1e344bc8aee41445272e677d802b774b1f8b34bdc3bb5697ba30f0fb5d52/litellm-1.68.0-py3-none-any.whl", hash = "sha256:3bca38848b1a5236b11aa6b70afa4393b60880198c939e582273f51a542d4759", size = 7684460 },
]

[[package]]
@@ -4228,6 +4232,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/96/31/6607dab48616902f76885dfcf62c08d929796fc3b2d2318faf9fd54dbed9/pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b", size = 18024 },
]

[[package]]
name = "pytest-randomly"
version = "3.16.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/68/d221ed7f4a2a49a664da721b8e87b52af6dd317af2a6cb51549cf17ac4b8/pytest_randomly-3.16.0.tar.gz", hash = "sha256:11bf4d23a26484de7860d82f726c0629837cf4064b79157bd18ec9d41d7feb26", size = 13367 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/22/70/b31577d7c46d8e2f9baccfed5067dd8475262a2331ffb0bfdf19361c9bde/pytest_randomly-3.16.0-py3-none-any.whl", hash = "sha256:8633d332635a1a0983d3bba19342196807f6afb17c3eef78e02c2f85dade45d6", size = 8396 },
]

[[package]]
name = "pytest-recording"
version = "0.13.2"
@@ -4254,6 +4270,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/10/77/a80e8f9126b95ffd5ad4d04bd14005c68dcbf0d88f53b2b14893f6cc7232/pytest_subprocess-1.5.2-py3-none-any.whl", hash = "sha256:23ac7732aa8bd45f1757265b1316eb72a7f55b41fb21e2ca22e149ba3629fa46", size = 20886 },
]

[[package]]
name = "pytest-timeout"
version = "2.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/93/0d/04719abc7a4bdb3a7a1f968f24b0f5253d698c9cc94975330e9d3145befb/pytest-timeout-2.3.1.tar.gz", hash = "sha256:12397729125c6ecbdaca01035b9e5239d4db97352320af155b3f5de1ba5165d9", size = 17697 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/03/27/14af9ef8321f5edc7527e47def2a21d8118c6f329a9342cc61387a0c0599/pytest_timeout-2.3.1-py3-none-any.whl", hash = "sha256:68188cb703edfc6a18fad98dc25a3c61e9f24d644b0b70f33af545219fc7813e", size = 14148 },
]

[[package]]
name = "python-bidi"
version = "0.6.3"