mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-26 04:42:41 +00:00
Compare commits
12 Commits
devin/1746
...
devin/1746
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79547fba25 | ||
|
|
171f8b63fd | ||
|
|
72df165b07 | ||
|
|
63eccf5e30 | ||
|
|
a98a44afb2 | ||
|
|
6e0f1fe38d | ||
|
|
c2bf2b3210 | ||
|
|
14579a7861 | ||
|
|
dabf02a90d | ||
|
|
2912c93d77 | ||
|
|
17474a3a0c | ||
|
|
f89c2bfb7e |
92
manual_test_csv_update.py
Normal file
92
manual_test_csv_update.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
Manual test script to verify CSV knowledge source update functionality.
|
||||
This script creates a CSV file, loads it as a knowledge source, updates the file,
|
||||
and verifies that the updated content is detected and loaded.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
|
||||
|
||||
|
||||
def test_csv_knowledge_source_updates():
|
||||
"""Test that CSVKnowledgeSource properly detects and loads updates to CSV files."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
csv_path = Path(tmpdir) / "test_updates.csv"
|
||||
|
||||
initial_csv_content = [
|
||||
["name", "age", "city"],
|
||||
["John", "30", "New York"],
|
||||
["Alice", "25", "San Francisco"],
|
||||
["Bob", "28", "Chicago"],
|
||||
]
|
||||
|
||||
with open(csv_path, "w") as f:
|
||||
for row in initial_csv_content:
|
||||
f.write(",".join(row) + "\n")
|
||||
|
||||
print(f"Created CSV file at {csv_path}")
|
||||
|
||||
csv_source = CSVKnowledgeSource(file_paths=[csv_path])
|
||||
|
||||
if not hasattr(csv_source, 'files_have_changed'):
|
||||
print("❌ TEST FAILED: files_have_changed method not found in CSVKnowledgeSource")
|
||||
return False
|
||||
|
||||
if not hasattr(csv_source, '_file_mtimes'):
|
||||
print("❌ TEST FAILED: _file_mtimes attribute not found in CSVKnowledgeSource")
|
||||
return False
|
||||
|
||||
knowledge = Knowledge(sources=[csv_source], collection_name="test_updates")
|
||||
|
||||
if not hasattr(knowledge, '_check_and_reload_sources'):
|
||||
print("❌ TEST FAILED: _check_and_reload_sources method not found in Knowledge")
|
||||
return False
|
||||
|
||||
print("✅ All required methods and attributes exist")
|
||||
|
||||
updated_csv_content = [
|
||||
["name", "age", "city"],
|
||||
["John", "30", "Boston"], # Changed city
|
||||
["Alice", "25", "San Francisco"],
|
||||
["Bob", "28", "Chicago"],
|
||||
["Eve", "22", "Miami"], # Added new person
|
||||
]
|
||||
|
||||
print("\nWaiting for 1 second before updating file...")
|
||||
time.sleep(1)
|
||||
|
||||
with open(csv_path, "w") as f:
|
||||
for row in updated_csv_content:
|
||||
f.write(",".join(row) + "\n")
|
||||
|
||||
print(f"Updated CSV file at {csv_path}")
|
||||
|
||||
if not csv_source.files_have_changed():
|
||||
print("❌ TEST FAILED: files_have_changed did not detect file modification")
|
||||
return False
|
||||
|
||||
print("✅ files_have_changed correctly detected file modification")
|
||||
|
||||
csv_source._record_file_mtimes()
|
||||
csv_source.content = csv_source.load_content()
|
||||
|
||||
content_str = str(csv_source.content)
|
||||
if "Boston" in content_str and "Eve" in content_str and "Miami" in content_str:
|
||||
print("✅ Content was correctly updated with new data")
|
||||
else:
|
||||
print("❌ TEST FAILED: Content was not updated with new data")
|
||||
return False
|
||||
|
||||
print("\n✅ TEST PASSED: CSV knowledge source correctly detects and loads file updates")
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_csv_knowledge_source_updates()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -11,7 +11,7 @@ dependencies = [
|
||||
# Core Dependencies
|
||||
"pydantic>=2.4.2",
|
||||
"openai>=1.13.3",
|
||||
"litellm==1.67.1",
|
||||
"litellm==1.68.0",
|
||||
"instructor>=1.3.3",
|
||||
# Text Processing
|
||||
"pdfplumber>=0.11.4",
|
||||
|
||||
@@ -2,7 +2,7 @@ import subprocess
|
||||
|
||||
import click
|
||||
|
||||
from crewai.cli.utils import get_crew
|
||||
from crewai.cli.utils import get_crews
|
||||
|
||||
|
||||
def reset_memories_command(
|
||||
@@ -26,35 +26,47 @@ def reset_memories_command(
|
||||
"""
|
||||
|
||||
try:
|
||||
crew = get_crew()
|
||||
if not crew:
|
||||
raise ValueError("No crew found.")
|
||||
if all:
|
||||
crew.reset_memories(command_type="all")
|
||||
click.echo("All memories have been reset.")
|
||||
return
|
||||
|
||||
if not any([long, short, entity, kickoff_outputs, knowledge]):
|
||||
if not any([long, short, entity, kickoff_outputs, knowledge, all]):
|
||||
click.echo(
|
||||
"No memory type specified. Please specify at least one type to reset."
|
||||
)
|
||||
return
|
||||
|
||||
if long:
|
||||
crew.reset_memories(command_type="long")
|
||||
click.echo("Long term memory has been reset.")
|
||||
if short:
|
||||
crew.reset_memories(command_type="short")
|
||||
click.echo("Short term memory has been reset.")
|
||||
if entity:
|
||||
crew.reset_memories(command_type="entity")
|
||||
click.echo("Entity memory has been reset.")
|
||||
if kickoff_outputs:
|
||||
crew.reset_memories(command_type="kickoff_outputs")
|
||||
click.echo("Latest Kickoff outputs stored has been reset.")
|
||||
if knowledge:
|
||||
crew.reset_memories(command_type="knowledge")
|
||||
click.echo("Knowledge has been reset.")
|
||||
crews = get_crews()
|
||||
if not crews:
|
||||
raise ValueError("No crew found.")
|
||||
for crew in crews:
|
||||
if all:
|
||||
crew.reset_memories(command_type="all")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed."
|
||||
)
|
||||
continue
|
||||
if long:
|
||||
crew.reset_memories(command_type="long")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset."
|
||||
)
|
||||
if short:
|
||||
crew.reset_memories(command_type="short")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset."
|
||||
)
|
||||
if entity:
|
||||
crew.reset_memories(command_type="entity")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset."
|
||||
)
|
||||
if kickoff_outputs:
|
||||
crew.reset_memories(command_type="kickoff_outputs")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset."
|
||||
)
|
||||
if knowledge:
|
||||
crew.reset_memories(command_type="knowledge")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset."
|
||||
)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
|
||||
|
||||
@@ -2,7 +2,8 @@ import os
|
||||
import shutil
|
||||
import sys
|
||||
from functools import reduce
|
||||
from typing import Any, Dict, List
|
||||
from inspect import isfunction, ismethod
|
||||
from typing import Any, Dict, List, get_type_hints
|
||||
|
||||
import click
|
||||
import tomli
|
||||
@@ -10,6 +11,7 @@ from rich.console import Console
|
||||
|
||||
from crewai.cli.constants import ENV_VARS
|
||||
from crewai.crew import Crew
|
||||
from crewai.flow import Flow
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
@@ -250,11 +252,11 @@ def write_env_file(folder_path, env_vars):
|
||||
file.write(f"{key}={value}\n")
|
||||
|
||||
|
||||
def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
|
||||
"""Get the crew instance from the crew.py file."""
|
||||
def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
"""Get the crew instances from the a file."""
|
||||
crew_instances = []
|
||||
try:
|
||||
import importlib.util
|
||||
import os
|
||||
|
||||
for root, _, files in os.walk("."):
|
||||
if crew_path in files:
|
||||
@@ -271,12 +273,10 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
try:
|
||||
if callable(attr) and hasattr(attr, "crew"):
|
||||
crew_instance = attr().crew()
|
||||
return crew_instance
|
||||
module_attr = getattr(module, attr_name)
|
||||
|
||||
try:
|
||||
crew_instances.extend(fetch_crews(module_attr))
|
||||
except Exception as e:
|
||||
print(f"Error processing attribute {attr_name}: {e}")
|
||||
continue
|
||||
@@ -286,7 +286,6 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
|
||||
import traceback
|
||||
|
||||
print(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
except (ImportError, AttributeError) as e:
|
||||
if require:
|
||||
console.print(
|
||||
@@ -300,7 +299,6 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
|
||||
if require:
|
||||
console.print("No valid Crew instance found in crew.py", style="bold red")
|
||||
raise SystemExit
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
if require:
|
||||
@@ -308,4 +306,36 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
|
||||
f"Unexpected error while loading crew: {str(e)}", style="bold red"
|
||||
)
|
||||
raise SystemExit
|
||||
return crew_instances
|
||||
|
||||
|
||||
def get_crew_instance(module_attr) -> Crew | None:
|
||||
if (
|
||||
callable(module_attr)
|
||||
and hasattr(module_attr, "is_crew_class")
|
||||
and module_attr.is_crew_class
|
||||
):
|
||||
return module_attr().crew()
|
||||
if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
|
||||
module_attr
|
||||
).get("return") is Crew:
|
||||
return module_attr()
|
||||
elif isinstance(module_attr, Crew):
|
||||
return module_attr
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_crews(module_attr) -> list[Crew]:
|
||||
crew_instances: list[Crew] = []
|
||||
|
||||
if crew_instance := get_crew_instance(module_attr):
|
||||
crew_instances.append(crew_instance)
|
||||
|
||||
if isinstance(module_attr, type) and issubclass(module_attr, Flow):
|
||||
instance = module_attr()
|
||||
for attr_name in dir(instance):
|
||||
attr = getattr(instance, attr_name)
|
||||
if crew_instance := get_crew_instance(attr):
|
||||
crew_instances.append(crew_instance)
|
||||
return crew_instances
|
||||
|
||||
@@ -6,7 +6,17 @@ import warnings
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
@@ -24,6 +34,7 @@ from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
@@ -69,7 +80,7 @@ from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
|
||||
|
||||
|
||||
class Crew(BaseModel):
|
||||
class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
Represents a group of agents, defining how they should collaborate and the tasks they should perform.
|
||||
|
||||
@@ -304,7 +315,9 @@ class Crew(BaseModel):
|
||||
"""Initialize private memory attributes."""
|
||||
self._external_memory = (
|
||||
# External memory doesn’t support a default value since it was designed to be managed entirely externally
|
||||
self.external_memory.set_crew(self) if self.external_memory else None
|
||||
self.external_memory.set_crew(self)
|
||||
if self.external_memory
|
||||
else None
|
||||
)
|
||||
|
||||
self._long_term_memory = self.long_term_memory
|
||||
@@ -333,6 +346,7 @@ class Crew(BaseModel):
|
||||
embedder=self.embedder,
|
||||
collection_name="crew",
|
||||
)
|
||||
self.knowledge.add_sources()
|
||||
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
@@ -1369,8 +1383,6 @@ class Crew(BaseModel):
|
||||
else:
|
||||
self._reset_specific_memory(command_type)
|
||||
|
||||
self._logger.log("info", f"{command_type} memory has been reset")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to reset {command_type} memory: {str(e)}"
|
||||
self._logger.log("error", error_msg)
|
||||
@@ -1391,8 +1403,14 @@ class Crew(BaseModel):
|
||||
if system is not None:
|
||||
try:
|
||||
system.reset()
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _reset_specific_memory(self, memory_type: str) -> None:
|
||||
"""Reset a specific memory system.
|
||||
@@ -1421,5 +1439,11 @@ class Crew(BaseModel):
|
||||
|
||||
try:
|
||||
memory_system.reset()
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
|
||||
) from e
|
||||
|
||||
44
src/crewai/flow/flow_trackable.py
Normal file
44
src/crewai/flow/flow_trackable.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import inspect
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf, model_validator
|
||||
|
||||
from crewai.flow import Flow
|
||||
|
||||
|
||||
class FlowTrackable(BaseModel):
|
||||
"""Mixin that tracks the Flow instance that instantiated the object, e.g. a
|
||||
Flow instance that created a Crew or Agent.
|
||||
|
||||
Automatically finds and stores a reference to the parent Flow instance by
|
||||
inspecting the call stack.
|
||||
"""
|
||||
|
||||
parent_flow: Optional[InstanceOf[Flow]] = Field(
|
||||
default=None,
|
||||
description="The parent flow of the instance, if it was created inside a flow.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_parent_flow(self, max_depth: int = 5) -> "FlowTrackable":
|
||||
frame = inspect.currentframe()
|
||||
|
||||
try:
|
||||
if frame is None:
|
||||
return self
|
||||
|
||||
frame = frame.f_back
|
||||
for _ in range(max_depth):
|
||||
if frame is None:
|
||||
break
|
||||
|
||||
candidate = frame.f_locals.get("self")
|
||||
if isinstance(candidate, Flow):
|
||||
self.parent_flow = candidate
|
||||
break
|
||||
|
||||
frame = frame.f_back
|
||||
finally:
|
||||
del frame
|
||||
|
||||
return self
|
||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
|
||||
@@ -12,10 +13,19 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
class Knowledge(BaseModel):
|
||||
"""
|
||||
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
|
||||
This class manages knowledge sources and provides methods to query them for relevant information.
|
||||
It automatically detects and reloads file-based knowledge sources when their underlying files change.
|
||||
|
||||
Args:
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
The knowledge sources to use for querying.
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
The storage backend for knowledge embeddings.
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
Configuration for the embedding model.
|
||||
collection_name: Optional[str] = None
|
||||
Name of the collection to use for storage.
|
||||
"""
|
||||
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
@@ -23,6 +33,7 @@ class Knowledge(BaseModel):
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
collection_name: Optional[str] = None
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -41,7 +52,6 @@ class Knowledge(BaseModel):
|
||||
)
|
||||
self.sources = sources
|
||||
self.storage.initialize_knowledge_storage()
|
||||
self._add_sources()
|
||||
|
||||
def query(
|
||||
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
@@ -55,6 +65,8 @@ class Knowledge(BaseModel):
|
||||
"""
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
self._check_and_reload_sources()
|
||||
|
||||
results = self.storage.search(
|
||||
query,
|
||||
@@ -62,8 +74,67 @@ class Knowledge(BaseModel):
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
return results
|
||||
|
||||
def _check_and_reload_sources(self):
|
||||
"""
|
||||
Check if any file-based knowledge sources have changed and reload them if necessary.
|
||||
|
||||
This method detects modifications to source files by comparing their modification timestamps
|
||||
with previously recorded values. When changes are detected, the source is reloaded and
|
||||
the storage is updated with the new content.
|
||||
|
||||
The method handles various file-related exceptions with specific error messages:
|
||||
- FileNotFoundError: When a source file no longer exists
|
||||
- PermissionError: When there are permission issues accessing a file
|
||||
- IOError: When there are I/O errors reading a file
|
||||
- ValueError: When there are issues with file content format
|
||||
- Other unexpected exceptions are also caught and logged
|
||||
|
||||
Each exception is logged with appropriate context to aid in troubleshooting.
|
||||
"""
|
||||
for source in self.sources:
|
||||
source_name = source.__class__.__name__
|
||||
try:
|
||||
if hasattr(source, 'files_have_changed') and source.files_have_changed():
|
||||
self._logger.log("info", f"Reloading modified source: {source_name}")
|
||||
|
||||
# Update file modification timestamps
|
||||
try:
|
||||
source._record_file_mtimes()
|
||||
except (PermissionError, IOError) as e:
|
||||
self._logger.log("warning", f"Could not record file timestamps for {source_name}: {str(e)}")
|
||||
|
||||
try:
|
||||
source.content = source.load_content()
|
||||
except FileNotFoundError as e:
|
||||
self._logger.log("error", f"File not found when loading content for {source_name}: {str(e)}")
|
||||
continue
|
||||
except PermissionError as e:
|
||||
self._logger.log("error", f"Permission error when loading content for {source_name}: {str(e)}")
|
||||
continue
|
||||
except IOError as e:
|
||||
self._logger.log("error", f"IO error when loading content for {source_name}: {str(e)}")
|
||||
continue
|
||||
except ValueError as e:
|
||||
self._logger.log("error", f"Invalid content format in {source_name}: {str(e)}")
|
||||
continue
|
||||
|
||||
try:
|
||||
source.add()
|
||||
self._logger.log("info", f"Successfully reloaded and updated {source_name}")
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Failed to update storage for {source_name}: {str(e)}")
|
||||
|
||||
except FileNotFoundError as e:
|
||||
self._logger.log("error", f"File not found when checking for updates in {source_name}: {str(e)}")
|
||||
except PermissionError as e:
|
||||
self._logger.log("error", f"Permission error when checking for updates in {source_name}: {str(e)}")
|
||||
except IOError as e:
|
||||
self._logger.log("error", f"IO error when checking for updates in {source_name}: {str(e)}")
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Unexpected error when checking for updates in {source_name}: {str(e)}")
|
||||
|
||||
def _add_sources(self):
|
||||
def add_sources(self):
|
||||
try:
|
||||
for source in self.sources:
|
||||
source.storage = self.storage
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
@@ -11,9 +12,24 @@ from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
"""Base class for knowledge sources that load content from files."""
|
||||
"""
|
||||
Base class for knowledge sources that load content from files.
|
||||
|
||||
This class provides common functionality for file-based knowledge sources,
|
||||
including file path validation, content loading, and change detection.
|
||||
It automatically tracks file modification times to detect when files have
|
||||
been updated and need to be reloaded.
|
||||
|
||||
Attributes:
|
||||
file_path: Deprecated. Use file_paths instead.
|
||||
file_paths: Path(s) to the file(s) containing knowledge data.
|
||||
content: Dictionary mapping file paths to their loaded content.
|
||||
storage: Storage backend for the knowledge data.
|
||||
safe_file_paths: Validated list of Path objects.
|
||||
"""
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||
default=None,
|
||||
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||
@@ -43,7 +59,34 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
"""Post-initialization method to load content."""
|
||||
self.safe_file_paths = self._process_file_paths()
|
||||
self.validate_content()
|
||||
self._record_file_mtimes()
|
||||
self.content = self.load_content()
|
||||
|
||||
def _record_file_mtimes(self):
|
||||
"""
|
||||
Record modification times of all files.
|
||||
|
||||
This method stores the current modification timestamps of all files
|
||||
in the _file_mtimes dictionary. These timestamps are later used to
|
||||
detect when files have been modified and need to be reloaded.
|
||||
|
||||
Thread-safe: Uses a lock to prevent concurrent modifications.
|
||||
"""
|
||||
with self._lock:
|
||||
self._file_mtimes = {}
|
||||
for path in self.safe_file_paths:
|
||||
try:
|
||||
if path.exists() and path.is_file():
|
||||
if os.access(path, os.R_OK):
|
||||
self._file_mtimes[path] = path.stat().st_mtime
|
||||
else:
|
||||
self._logger.log("warning", f"File {path} is not readable.")
|
||||
except PermissionError as e:
|
||||
self._logger.log("error", f"Permission error when recording file timestamp for {path}: {str(e)}")
|
||||
except IOError as e:
|
||||
self._logger.log("error", f"IO error when recording file timestamp for {path}: {str(e)}")
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Unexpected error when recording file timestamp for {path}: {str(e)}")
|
||||
|
||||
@abstractmethod
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
@@ -107,3 +150,41 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
)
|
||||
|
||||
return [self.convert_to_path(path) for path in path_list]
|
||||
|
||||
def files_have_changed(self) -> bool:
|
||||
"""
|
||||
Check if any of the files have been modified since they were last loaded.
|
||||
|
||||
This method compares the current modification timestamps of files with the
|
||||
previously recorded timestamps to detect changes. When a file has been modified,
|
||||
it logs the change and returns True to trigger a reload.
|
||||
|
||||
Returns:
|
||||
bool: True if any file has been modified, False otherwise.
|
||||
"""
|
||||
for path in self.safe_file_paths:
|
||||
try:
|
||||
if not path.exists():
|
||||
self._logger.log("warning", f"File {path} no longer exists.")
|
||||
continue
|
||||
|
||||
if not path.is_file():
|
||||
self._logger.log("warning", f"Path {path} is not a file.")
|
||||
continue
|
||||
|
||||
if not os.access(path, os.R_OK):
|
||||
self._logger.log("warning", f"File {path} is not readable.")
|
||||
continue
|
||||
|
||||
current_mtime = path.stat().st_mtime
|
||||
if path not in self._file_mtimes or current_mtime > self._file_mtimes[path]:
|
||||
self._logger.log("info", f"File {path} has been modified. Reloading data.")
|
||||
return True
|
||||
except PermissionError as e:
|
||||
self._logger.log("error", f"Permission error when checking file {path}: {str(e)}")
|
||||
except IOError as e:
|
||||
self._logger.log("error", f"IO error when checking file {path}: {str(e)}")
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Unexpected error when checking file {path}: {str(e)}")
|
||||
|
||||
return False
|
||||
|
||||
@@ -13,6 +13,7 @@ from crewai.agents.parser import (
|
||||
AgentFinish,
|
||||
OutputParserException,
|
||||
)
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.llm import LLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
@@ -80,7 +81,7 @@ class LiteAgentOutput(BaseModel):
|
||||
return self.raw
|
||||
|
||||
|
||||
class LiteAgent(BaseModel):
|
||||
class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""
|
||||
A lightweight agent that can process messages and use tools.
|
||||
|
||||
@@ -162,7 +163,7 @@ class LiteAgent(BaseModel):
|
||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_iterations: int = PrivateAttr(default=0)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_llm(self):
|
||||
"""Set up the LLM and other components after initialization."""
|
||||
|
||||
@@ -351,6 +351,7 @@ class LLM(BaseLLM):
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
@@ -368,9 +369,6 @@ class LLM(BaseLLM):
|
||||
"reasoning_effort": self.reasoning_effort,
|
||||
**self.additional_params,
|
||||
}
|
||||
|
||||
if self.stop and self.supports_stop_words():
|
||||
params["stop"] = self.stop
|
||||
|
||||
# Remove None values from params
|
||||
return {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import warnings
|
||||
@@ -14,6 +15,8 @@ from crewai.telemetry.constants import (
|
||||
CREWAI_TELEMETRY_SERVICE_NAME,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_warnings():
|
||||
@@ -28,7 +31,10 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
||||
)
|
||||
from opentelemetry.sdk.resources import SERVICE_NAME, Resource # noqa: E402
|
||||
from opentelemetry.sdk.trace import TracerProvider # noqa: E402
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor # noqa: E402
|
||||
from opentelemetry.sdk.trace.export import ( # noqa: E402
|
||||
BatchSpanProcessor,
|
||||
SpanExportResult,
|
||||
)
|
||||
from opentelemetry.trace import Span, Status, StatusCode # noqa: E402
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -36,6 +42,15 @@ if TYPE_CHECKING:
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class SafeOTLPSpanExporter(OTLPSpanExporter):
|
||||
def export(self, spans) -> SpanExportResult:
|
||||
try:
|
||||
return super().export(spans)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return SpanExportResult.FAILURE
|
||||
|
||||
|
||||
class Telemetry:
|
||||
"""A class to handle anonymous telemetry for the crewai package.
|
||||
|
||||
@@ -64,7 +79,7 @@ class Telemetry:
|
||||
self.provider = TracerProvider(resource=self.resource)
|
||||
|
||||
processor = BatchSpanProcessor(
|
||||
OTLPSpanExporter(
|
||||
SafeOTLPSpanExporter(
|
||||
endpoint=f"{CREWAI_TELEMETRY_BASE_URL}/v1/traces",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
221
tests/cassettes/test_telemetry_fails_due_connect_timeout.yaml
Normal file
221
tests/cassettes/test_telemetry_fails_due_connect_timeout.yaml
Normal file
File diff suppressed because one or more lines are too long
@@ -18,6 +18,7 @@ from crewai.cli.cli import (
|
||||
train,
|
||||
version,
|
||||
)
|
||||
from crewai.crew import Crew
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -55,81 +56,133 @@ def test_train_invalid_string_iterations(train_crew, runner):
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_all_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
@pytest.fixture
|
||||
def mock_crew():
|
||||
_mock = mock.Mock(spec=Crew, name="test_crew")
|
||||
_mock.name = "test_crew"
|
||||
return _mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_crews(mock_crew):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
|
||||
) as mock_get_crew:
|
||||
yield mock_get_crew
|
||||
|
||||
|
||||
def test_reset_all_memories(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-a"])
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="all")
|
||||
assert result.output == "All memories have been reset.\n"
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="all")
|
||||
assert (
|
||||
f"[Crew ({crew.name})] Reset memories command has been completed."
|
||||
in result.output
|
||||
)
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_short_term_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
def test_reset_short_term_memories(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-s"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="short")
|
||||
assert (
|
||||
f"[Crew ({crew.name})] Short term memory has been reset." in result.output
|
||||
)
|
||||
call_count += 1
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="short")
|
||||
assert result.output == "Short term memory has been reset.\n"
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_entity_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
def test_reset_entity_memories(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-e"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="entity")
|
||||
assert f"[Crew ({crew.name})] Entity memory has been reset." in result.output
|
||||
call_count += 1
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="entity")
|
||||
assert result.output == "Entity memory has been reset.\n"
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_long_term_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
def test_reset_long_term_memories(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-l"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="long")
|
||||
assert f"[Crew ({crew.name})] Long term memory has been reset." in result.output
|
||||
call_count += 1
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="long")
|
||||
assert result.output == "Long term memory has been reset.\n"
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_kickoff_outputs(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
def test_reset_kickoff_outputs(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-k"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs")
|
||||
assert (
|
||||
f"[Crew ({crew.name})] Latest Kickoff outputs stored has been reset."
|
||||
in result.output
|
||||
)
|
||||
call_count += 1
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs")
|
||||
assert result.output == "Latest Kickoff outputs stored has been reset.\n"
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_multiple_memory_flags(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
def test_reset_multiple_memory_flags(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-s", "-l"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_has_calls(
|
||||
[mock.call(command_type="long"), mock.call(command_type="short")]
|
||||
)
|
||||
assert (
|
||||
f"[Crew ({crew.name})] Long term memory has been reset.\n"
|
||||
f"[Crew ({crew.name})] Short term memory has been reset.\n" in result.output
|
||||
)
|
||||
call_count += 1
|
||||
|
||||
# Check that reset_memories was called twice with the correct arguments
|
||||
assert mock_crew.reset_memories.call_count == 2
|
||||
mock_crew.reset_memories.assert_has_calls(
|
||||
[mock.call(command_type="long"), mock.call(command_type="short")]
|
||||
)
|
||||
assert (
|
||||
result.output
|
||||
== "Long term memory has been reset.\nShort term memory has been reset.\n"
|
||||
)
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_knowledge(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
def test_reset_knowledge(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["--knowledge"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="knowledge")
|
||||
assert f"[Crew ({crew.name})] Knowledge has been reset." in result.output
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
|
||||
def test_reset_memory_from_many_crews(mock_get_crews, runner):
|
||||
|
||||
crews = []
|
||||
for crew_id in ["id-1234", "id-5678"]:
|
||||
mock_crew = mock.Mock(spec=Crew)
|
||||
mock_crew.name = None
|
||||
mock_crew.id = crew_id
|
||||
crews.append(mock_crew)
|
||||
|
||||
mock_get_crews.return_value = crews
|
||||
|
||||
# Run the command
|
||||
result = runner.invoke(reset_memories, ["--knowledge"])
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="knowledge")
|
||||
assert result.output == "Knowledge has been reset.\n"
|
||||
call_count = 0
|
||||
for crew in crews:
|
||||
call_count += 1
|
||||
crew.reset_memories.assert_called_once_with(command_type="knowledge")
|
||||
assert f"[Crew ({crew.id})] Knowledge has been reset." in result.output
|
||||
|
||||
assert call_count == 2, "reset_memories should have been called twice"
|
||||
|
||||
|
||||
def test_reset_no_memory_flags(runner):
|
||||
|
||||
@@ -3,12 +3,13 @@ import tempfile
|
||||
import unittest
|
||||
import unittest.mock
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest import raises
|
||||
|
||||
from crewai.cli.authentication.utils import TokenManager
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
|
||||
|
||||
@@ -23,17 +24,20 @@ def in_temp_dir():
|
||||
os.chdir(original_dir)
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
def test_create_success(mock_subprocess):
|
||||
with in_temp_dir():
|
||||
tool_command = ToolCommand()
|
||||
@pytest.fixture
|
||||
def tool_command():
|
||||
TokenManager().save_tokens("test-token", 36000)
|
||||
tool_command = ToolCommand()
|
||||
with patch.object(tool_command, "login"):
|
||||
yield tool_command
|
||||
|
||||
with (
|
||||
patch.object(tool_command, "login") as mock_login,
|
||||
patch("sys.stdout", new=StringIO()) as fake_out,
|
||||
):
|
||||
tool_command.create("test-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
def test_create_success(mock_subprocess, capsys, tool_command):
|
||||
with in_temp_dir():
|
||||
tool_command.create("test-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Creating custom tool test_tool..." in output
|
||||
|
||||
assert os.path.isdir("test_tool")
|
||||
assert os.path.isfile(os.path.join("test_tool", "README.md"))
|
||||
@@ -47,15 +51,12 @@ def test_create_success(mock_subprocess):
|
||||
content = f.read()
|
||||
assert "class TestTool" in content
|
||||
|
||||
mock_login.assert_called_once()
|
||||
mock_subprocess.assert_called_once_with(["git", "init"], check=True)
|
||||
|
||||
assert "Creating custom tool test_tool..." in output
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_success(mock_get, mock_subprocess_run):
|
||||
def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
mock_get_response.json.return_value = {
|
||||
@@ -65,11 +66,9 @@ def test_install_success(mock_get, mock_subprocess_run):
|
||||
mock_get.return_value = mock_get_response
|
||||
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
tool_command.install("sample-tool")
|
||||
output = fake_out.getvalue()
|
||||
tool_command.install("sample-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Successfully installed sample-tool" in output
|
||||
|
||||
mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()])
|
||||
mock_subprocess_run.assert_any_call(
|
||||
@@ -86,54 +85,42 @@ def test_install_success(mock_get, mock_subprocess_run):
|
||||
env=unittest.mock.ANY,
|
||||
)
|
||||
|
||||
assert "Successfully installed sample-tool" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_tool_not_found(mock_get):
|
||||
def test_install_tool_not_found(mock_get, capsys, tool_command):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 404
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.install("non-existent-tool")
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
with raises(SystemExit):
|
||||
tool_command.install("non-existent-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "No tool found with this name" in output
|
||||
|
||||
mock_get.assert_called_once_with("non-existent-tool")
|
||||
assert "No tool found with this name" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_api_error(mock_get):
|
||||
def test_install_api_error(mock_get, capsys, tool_command):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 500
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.install("error-tool")
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
with raises(SystemExit):
|
||||
tool_command.install("error-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Failed to get tool details" in output
|
||||
|
||||
mock_get.assert_called_once_with("error-tool")
|
||||
assert "Failed to get tool details" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
|
||||
def test_publish_when_not_in_sync(mock_is_synced):
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out, raises(SystemExit):
|
||||
tool_command = ToolCommand()
|
||||
def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command):
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
assert "Local changes need to be resolved before publishing" in fake_out.getvalue()
|
||||
output = capsys.readouterr().out
|
||||
assert "Local changes need to be resolved before publishing" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@@ -157,13 +144,13 @@ def test_publish_when_not_in_sync_and_force(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True, force=True)
|
||||
|
||||
mock_get_project_name.assert_called_with(require=True)
|
||||
@@ -205,13 +192,13 @@ def test_publish_success(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
mock_get_project_name.assert_called_with(require=True)
|
||||
@@ -251,25 +238,22 @@ def test_publish_failure(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
capsys,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 422
|
||||
mock_publish_response.json.return_value = {"name": ["is already taken"]}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.publish(is_public=True)
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = capsys.readouterr().out
|
||||
assert "Failed to complete operation" in output
|
||||
assert "Name is already taken" in output
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@@ -290,6 +274,8 @@ def test_publish_api_error(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
capsys,
|
||||
tool_command,
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
@@ -297,14 +283,9 @@ def test_publish_api_error(
|
||||
mock_response.ok = False
|
||||
mock_publish.return_value = mock_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.publish(is_public=True)
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = capsys.readouterr().out
|
||||
assert "Request to Enterprise API failed" in output
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
assert "Request to Enterprise API failed" in output
|
||||
|
||||
@@ -17,6 +17,7 @@ from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.flow import Flow, listen, start
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
@@ -2164,7 +2165,6 @@ def test_tools_with_custom_caching():
|
||||
with patch.object(
|
||||
CacheHandler, "add", wraps=crew._cache_handler.add
|
||||
) as add_to_cache:
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
# Check that add_to_cache was called exactly twice
|
||||
@@ -4351,3 +4351,35 @@ def test_crew_copy_with_memory():
|
||||
raise e # Re-raise other validation errors
|
||||
except Exception as e:
|
||||
pytest.fail(f"Copying crew raised an unexpected exception: {e}")
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_outside_flow(researcher, writer):
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
process=Process.sequential,
|
||||
tasks=[
|
||||
Task(description="Task 1", expected_output="output", agent=researcher),
|
||||
Task(description="Task 2", expected_output="output", agent=writer),
|
||||
],
|
||||
)
|
||||
assert crew.parent_flow is None
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_inside_flow(researcher, writer):
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def start(self):
|
||||
return Crew(
|
||||
agents=[researcher, writer],
|
||||
process=Process.sequential,
|
||||
tasks=[
|
||||
Task(
|
||||
description="Task 1", expected_output="output", agent=researcher
|
||||
),
|
||||
Task(description="Task 2", expected_output="output", agent=writer),
|
||||
],
|
||||
)
|
||||
|
||||
flow = MyFlow()
|
||||
result = flow.kickoff()
|
||||
assert result.parent_flow is flow
|
||||
|
||||
85
tests/knowledge/test_csv_knowledge_source_updates.py
Normal file
85
tests/knowledge/test_csv_knowledge_source_updates.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
|
||||
|
||||
@patch('crewai.knowledge.storage.knowledge_storage.KnowledgeStorage.search')
|
||||
@patch('crewai.knowledge.source.csv_knowledge_source.CSVKnowledgeSource.add')
|
||||
def test_csv_knowledge_source_updates(mock_add, mock_search, tmpdir):
|
||||
"""Test that CSVKnowledgeSource properly detects and loads updates to CSV files."""
|
||||
mock_search.side_effect = [
|
||||
[{"context": "name,age,city\nJohn,30,New York\nAlice,25,San Francisco\nBob,28,Chicago"}],
|
||||
[{"context": "name,age,city\nJohn,30,Boston\nAlice,25,San Francisco\nBob,28,Chicago\nEve,22,Miami"}],
|
||||
[{"context": "name,age,city\nJohn,30,Boston\nAlice,25,San Francisco\nBob,28,Chicago\nEve,22,Miami"}]
|
||||
]
|
||||
|
||||
csv_path = str(tmpdir / "test_updates.csv")
|
||||
|
||||
initial_csv_content = [
|
||||
["name", "age", "city"],
|
||||
["John", "30", "New York"],
|
||||
["Alice", "25", "San Francisco"],
|
||||
["Bob", "28", "Chicago"],
|
||||
]
|
||||
|
||||
with open(csv_path, "w") as f:
|
||||
for row in initial_csv_content:
|
||||
f.write(",".join(row) + "\n")
|
||||
|
||||
csv_source = CSVKnowledgeSource(file_paths=[csv_path])
|
||||
|
||||
original_files_have_changed = csv_source.files_have_changed
|
||||
files_changed_called = [False]
|
||||
|
||||
def spy_files_have_changed():
|
||||
files_changed_called[0] = True
|
||||
return original_files_have_changed()
|
||||
|
||||
csv_source.files_have_changed = spy_files_have_changed
|
||||
|
||||
knowledge = Knowledge(sources=[csv_source], collection_name="test_updates")
|
||||
|
||||
assert hasattr(knowledge, '_check_and_reload_sources'), "Knowledge class is missing _check_and_reload_sources method"
|
||||
|
||||
initial_results = knowledge.query(["John"])
|
||||
assert any("John" in result["context"] for result in initial_results)
|
||||
assert any("New York" in result["context"] for result in initial_results)
|
||||
|
||||
mock_add.reset_mock()
|
||||
files_changed_called[0] = False
|
||||
|
||||
updated_csv_content = [
|
||||
["name", "age", "city"],
|
||||
["John", "30", "Boston"], # Changed city
|
||||
["Alice", "25", "San Francisco"],
|
||||
["Bob", "28", "Chicago"],
|
||||
["Eve", "22", "Miami"], # Added new person
|
||||
]
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
csv_path_str = str(csv_path)
|
||||
with open(csv_path_str, "w") as f:
|
||||
for row in updated_csv_content:
|
||||
f.write(",".join(row) + "\n")
|
||||
|
||||
updated_results = knowledge.query(["John"])
|
||||
|
||||
assert files_changed_called[0], "files_have_changed method was not called during query"
|
||||
|
||||
assert mock_add.called, "add method was not called to reload the data"
|
||||
|
||||
assert any("John" in result["context"] for result in updated_results)
|
||||
assert any("Boston" in result["context"] for result in updated_results)
|
||||
assert not any("New York" in result["context"] for result in updated_results)
|
||||
|
||||
new_results = knowledge.query(["Eve"])
|
||||
assert any("Eve" in result["context"] for result in new_results)
|
||||
assert any("Miami" in result["context"] for result in new_results)
|
||||
69
tests/telemetry/test_telemetry.py
Normal file
69
tests/telemetry/test_telemetry.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.telemetry import Telemetry
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_var,value,expected_ready",
|
||||
[
|
||||
("OTEL_SDK_DISABLED", "true", False),
|
||||
("OTEL_SDK_DISABLED", "TRUE", False),
|
||||
("CREWAI_DISABLE_TELEMETRY", "true", False),
|
||||
("CREWAI_DISABLE_TELEMETRY", "TRUE", False),
|
||||
("OTEL_SDK_DISABLED", "false", True),
|
||||
("CREWAI_DISABLE_TELEMETRY", "false", True),
|
||||
],
|
||||
)
|
||||
def test_telemetry_environment_variables(env_var, value, expected_ready):
|
||||
"""Test telemetry state with different environment variable configurations."""
|
||||
with patch.dict(os.environ, {env_var: value}):
|
||||
with patch("crewai.telemetry.telemetry.TracerProvider"):
|
||||
telemetry = Telemetry()
|
||||
assert telemetry.ready is expected_ready
|
||||
|
||||
|
||||
def test_telemetry_enabled_by_default():
|
||||
"""Test that telemetry is enabled by default."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with patch("crewai.telemetry.telemetry.TracerProvider"):
|
||||
telemetry = Telemetry()
|
||||
assert telemetry.ready is True
|
||||
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
|
||||
@patch("crewai.telemetry.telemetry.logger.error")
|
||||
@patch(
|
||||
"opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter.export",
|
||||
side_effect=Exception("Test exception"),
|
||||
)
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_telemetry_fails_due_connect_timeout(export_mock, logger_mock):
|
||||
error = Exception("Test exception")
|
||||
export_mock.side_effect = error
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
with tracer.start_as_current_span("test-span"):
|
||||
agent = Agent(
|
||||
role="agent",
|
||||
llm="gpt-4o-mini",
|
||||
goal="Just say hi",
|
||||
backstory="You are a helpful assistant that just says hi",
|
||||
)
|
||||
task = Task(
|
||||
description="Just say hi",
|
||||
expected_output="hi",
|
||||
agent=agent,
|
||||
)
|
||||
crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
|
||||
trace.get_tracer_provider().force_flush()
|
||||
|
||||
export_mock.assert_called_once()
|
||||
logger_mock.assert_called_once_with(error)
|
||||
@@ -1,13 +1,16 @@
|
||||
import asyncio
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai import LLM, Agent
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.agent_events import LiteAgentExecutionStartedEvent
|
||||
from crewai.utilities.events.tool_usage_events import ToolUsageStartedEvent
|
||||
|
||||
|
||||
@@ -255,3 +258,60 @@ async def test_lite_agent_returns_usage_metrics_async():
|
||||
assert "21 million" in result.raw or "37 million" in result.raw
|
||||
assert result.usage_metrics is not None
|
||||
assert result.usage_metrics["total_tokens"] > 0
|
||||
|
||||
|
||||
class TestFlow(Flow):
|
||||
"""A test flow that creates and runs an agent."""
|
||||
|
||||
def __init__(self, llm, tools):
|
||||
self.llm = llm
|
||||
self.tools = tools
|
||||
super().__init__()
|
||||
|
||||
@start()
|
||||
def start(self):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test Goal",
|
||||
backstory="Test Backstory",
|
||||
llm=self.llm,
|
||||
tools=self.tools,
|
||||
)
|
||||
return agent.kickoff("Test query")
|
||||
|
||||
|
||||
def verify_agent_parent_flow(result, agent, flow):
|
||||
"""Verify that both the result and agent have the correct parent flow."""
|
||||
assert result.parent_flow is flow
|
||||
assert agent is not None
|
||||
assert agent.parent_flow is flow
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_inside_flow():
|
||||
captured_agent = None
|
||||
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.call.return_value = "Test response"
|
||||
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def start(self):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test Goal",
|
||||
backstory="Test Backstory",
|
||||
llm=mock_llm,
|
||||
tools=[WebSearchTool()],
|
||||
)
|
||||
return agent.kickoff("Test query")
|
||||
|
||||
flow = MyFlow()
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def capture_agent(source, event):
|
||||
nonlocal captured_agent
|
||||
captured_agent = source
|
||||
|
||||
result = flow.kickoff()
|
||||
assert captured_agent.parent_flow is flow
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.llm import LLM
|
||||
|
||||
|
||||
class TestLLM(unittest.TestCase):
|
||||
@patch("crewai.llm.litellm.completion")
|
||||
@patch("crewai.llm.LLM.supports_stop_words")
|
||||
def test_call_with_supported_stop_words(self, mock_supports_stop_words, mock_completion):
|
||||
mock_supports_stop_words.return_value = True
|
||||
|
||||
message = SimpleNamespace(content="Hello, World!")
|
||||
choice = SimpleNamespace(message=message)
|
||||
response = SimpleNamespace(choices=[choice])
|
||||
mock_completion.return_value = response
|
||||
|
||||
llm = LLM(model="gpt-4", stop=["STOP"])
|
||||
|
||||
messages = [{"role": "user", "content": "Say Hello"}]
|
||||
result = llm.call(messages)
|
||||
|
||||
mock_completion.assert_called_once()
|
||||
call_args = mock_completion.call_args[1]
|
||||
self.assertIn("stop", call_args)
|
||||
self.assertEqual(call_args["stop"], ["STOP"])
|
||||
self.assertEqual(result, "Hello, World!")
|
||||
|
||||
@patch("crewai.llm.litellm.completion")
|
||||
@patch("crewai.llm.LLM.supports_stop_words")
|
||||
def test_call_with_unsupported_stop_words(self, mock_supports_stop_words, mock_completion):
|
||||
mock_supports_stop_words.return_value = False
|
||||
|
||||
message = SimpleNamespace(content="Hello, World!")
|
||||
choice = SimpleNamespace(message=message)
|
||||
response = SimpleNamespace(choices=[choice])
|
||||
mock_completion.return_value = response
|
||||
|
||||
llm = LLM(model="o3", stop=["STOP"])
|
||||
|
||||
messages = [{"role": "user", "content": "Say Hello"}]
|
||||
result = llm.call(messages)
|
||||
|
||||
mock_completion.assert_called_once()
|
||||
call_args = mock_completion.call_args[1]
|
||||
self.assertNotIn("stop", call_args)
|
||||
self.assertEqual(result, "Hello, World!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
8
uv.lock
generated
8
uv.lock
generated
@@ -835,7 +835,7 @@ requires-dist = [
|
||||
{ name = "json-repair", specifier = ">=0.25.2" },
|
||||
{ name = "json5", specifier = ">=0.10.0" },
|
||||
{ name = "jsonref", specifier = ">=1.1.0" },
|
||||
{ name = "litellm", specifier = "==1.67.1" },
|
||||
{ name = "litellm", specifier = "==1.68.0" },
|
||||
{ name = "mem0ai", marker = "extra == 'mem0'", specifier = ">=0.1.94" },
|
||||
{ name = "openai", specifier = ">=1.13.3" },
|
||||
{ name = "openpyxl", specifier = ">=3.1.5" },
|
||||
@@ -2387,7 +2387,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.67.1"
|
||||
version = "1.68.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "aiohttp" },
|
||||
@@ -2402,9 +2402,9 @@ dependencies = [
|
||||
{ name = "tiktoken" },
|
||||
{ name = "tokenizers" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/54/a4/bb3e9ae59e5a9857443448de7c04752630dc84cddcbd8cee037c0976f44f/litellm-1.67.1.tar.gz", hash = "sha256:78eab1bd3d759ec13aa4a05864356a4a4725634e78501db609d451bf72150ee7", size = 7242044 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ba/22/138545b646303ca3f4841b69613c697b9d696322a1386083bb70bcbba60b/litellm-1.68.0.tar.gz", hash = "sha256:9fb24643db84dfda339b64bafca505a2eef857477afbc6e98fb56512c24dbbfa", size = 7314051 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/88/86/c14d3c24ae13c08296d068e6f79fd4bd17a0a07bddbda94990b87c35d20e/litellm-1.67.1-py3-none-any.whl", hash = "sha256:8fff5b2a16b63bb594b94d6c071ad0f27d3d8cd4348bd5acea2fd40c8e0c11e8", size = 7607266 },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/af/1e344bc8aee41445272e677d802b774b1f8b34bdc3bb5697ba30f0fb5d52/litellm-1.68.0-py3-none-any.whl", hash = "sha256:3bca38848b1a5236b11aa6b70afa4393b60880198c939e582273f51a542d4759", size = 7684460 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user