Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-07 23:28:30 +00:00)

Compare commits: devin/1739...devin/1744 (1 commit)

| Author | SHA1 | Date |
|---|---|---|
| | 4748597667 | |
@@ -1076,50 +1076,18 @@ class Crew(BaseModel):
        self,
        n_iterations: int,
        openai_model_name: Optional[str] = None,
        llm: Optional[Union[str, LLM]] = None,
        inputs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Test and evaluate the Crew with the given inputs for n iterations.

        Args:
            n_iterations: Number of test iterations to run
            openai_model_name: (Deprecated) Name of OpenAI model to use for evaluation
            llm: LLM instance or model name to use for evaluation
            inputs: Optional dictionary of inputs to pass to the crew
        """
        if openai_model_name:
            warnings.warn(
                "openai_model_name is deprecated and will be removed in future versions. Use llm parameter instead.",
                DeprecationWarning,
                stacklevel=2
            )

        """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
        test_crew = self.copy()
        model = llm if llm else openai_model_name

        try:
            if not model:
                raise ValueError(
                    "Either llm or openai_model_name must be provided. Please provide either "
                    "a custom LLM instance or an OpenAI model name."
                )
            if isinstance(model, LLM):
                if not hasattr(model, 'model'):
                    raise ValueError("Provided LLM instance must have a 'model' attribute")
            elif isinstance(model, str):
                model = LLM(model=model)
            else:
                raise ValueError("LLM must be either a string model name or an LLM instance")
        except Exception as e:
            raise ValueError(f"Failed to initialize LLM: {str(e)}")

        self._test_execution_span = test_crew._telemetry.test_execution_span(
            test_crew,
            n_iterations,
            inputs,
            str(model),  # type: ignore[arg-type]
            openai_model_name,  # type: ignore[arg-type]
        )  # type: ignore[arg-type]
        evaluator = CrewEvaluator(test_crew, model)
        evaluator = CrewEvaluator(test_crew, openai_model_name)  # type: ignore[arg-type]

        for i in range(1, n_iterations + 1):
            evaluator.set_iteration(i)
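
For orientation, a minimal sketch of the two call styles this `test` signature supports. The agent/task/crew setup and the model name are illustrative assumptions (a configured OpenAI API key is also assumed); only the two `crew.test(...)` calls mirror what the diff describes:

```python
from crewai import Agent, Crew, Task
from crewai.llm import LLM

# Minimal crew, purely for illustration.
agent = Agent(role="Researcher", goal="Find facts", backstory="Experienced analyst.")
task = Task(
    description="List three facts about AI.",
    expected_output="Three bullet points.",
    agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])

# New style: pass an LLM instance (or a plain model-name string) via `llm`.
crew.test(n_iterations=2, llm=LLM(model="gpt-4o-mini"), inputs={"topic": "AI"})

# Deprecated style: `openai_model_name` still works but emits a DeprecationWarning.
crew.test(n_iterations=2, openai_model_name="gpt-4o-mini", inputs={"topic": "AI"})
```
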
@@ -1,3 +1,5 @@
from typing import Optional

from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
@@ -38,7 +40,7 @@ class EntityMemory(Memory):
        )
        super().__init__(storage)

    def save(self, item: EntityMemoryItem) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
    def save(self, item: EntityMemoryItem, custom_key: Optional[str] = None) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
        """Saves an entity item into the SQLite storage."""
        if self.memory_provider == "mem0":
            data = f"""
@@ -49,7 +51,7 @@ class EntityMemory(Memory):
            """
        else:
            data = f"{item.name}({item.type}): {item.description}"
        super().save(data, item.metadata)
        super().save(data, item.metadata, custom_key=custom_key)

    def reset(self) -> None:
        try:
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
@@ -19,9 +19,12 @@ class LongTermMemory(Memory):
        storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
        super().__init__(storage)

    def save(self, item: LongTermMemoryItem) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
    def save(self, item: LongTermMemoryItem, custom_key: Optional[str] = None) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
        metadata = item.metadata
        metadata.update({"agent": item.agent, "expected_output": item.expected_output})
        if custom_key:
            metadata.update({"custom_key": custom_key})

        self.storage.save(  # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
            task_description=item.task,
            score=metadata["quality"],
@@ -29,8 +32,8 @@ class LongTermMemory(Memory):
            datetime=item.datetime,
        )

    def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]:  # type: ignore # signature of "search" incompatible with supertype "Memory"
        return self.storage.load(task, latest_n)  # type: ignore # BUG?: "Storage" has no attribute "load"
    def search(self, task: str, latest_n: int = 3, custom_key: Optional[str] = None) -> List[Dict[str, Any]]:  # type: ignore # signature of "search" incompatible with supertype "Memory"
        return self.storage.load(task, latest_n, custom_key)  # type: ignore # BUG?: "Storage" has no attribute "load"

    def reset(self) -> None:
        self.storage.reset()
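
A hedged sketch of the scoped long-term lookup this change enables, assuming the constructor accepts the optional `path` used in the hunk above; the database path, task text, and key value are illustrative:

```python
from crewai.memory.long_term.long_term_memory import LongTermMemory

# SQLite-backed storage, as in the hunk above; the path is an example.
ltm = LongTermMemory(path="./long_term_memory.db")

# On save, custom_key is folded into the item metadata; on search, it is
# forwarded to storage.load, so only rows tagged with this key come back.
rows = ltm.search(task="Summarize AI trends", latest_n=3, custom_key="user123")
print(rows)
```
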
@@ -5,7 +5,10 @@ from crewai.memory.storage.rag_storage import RAGStorage

class Memory:
    """
    Base class for memory, now supporting agent tags and generic metadata.
    Base class for memory, now supporting agent tags, generic metadata, and custom keys.

    Custom keys allow scoping memories to specific entities (users, accounts, sessions),
    retrieving memories contextually, and preventing data leakage across logical boundaries.
    """

    def __init__(self, storage: RAGStorage):
@@ -16,10 +19,13 @@ class Memory:
        value: Any,
        metadata: Optional[Dict[str, Any]] = None,
        agent: Optional[str] = None,
        custom_key: Optional[str] = None,
    ) -> None:
        metadata = metadata or {}
        if agent:
            metadata["agent"] = agent
        if custom_key:
            metadata["custom_key"] = custom_key

        self.storage.save(value, metadata)

@@ -28,7 +34,12 @@ class Memory:
        query: str,
        limit: int = 3,
        score_threshold: float = 0.35,
        custom_key: Optional[str] = None,
    ) -> List[Any]:
        filter_dict = None
        if custom_key:
            filter_dict = {"custom_key": {"$eq": custom_key}}

        return self.storage.search(
            query=query, limit=limit, score_threshold=score_threshold
            query=query, limit=limit, score_threshold=score_threshold, filter=filter_dict
        )
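
A small sketch of how the base class translates `custom_key` into metadata on save and into an equality filter on search; the storage is mocked so the forwarded arguments are visible, and the values are illustrative:

```python
from unittest.mock import MagicMock

from crewai.memory.memory import Memory

storage = MagicMock()
memory = Memory(storage)

# save() folds agent and custom_key into the metadata dict before delegating.
memory.save("User prefers short answers", agent="assistant", custom_key="user123")
print(storage.save.call_args)  # metadata includes {"agent": ..., "custom_key": "user123"}

# search() turns custom_key into a ChromaDB-style equality filter.
memory.search("preferences", custom_key="user123")
print(storage.search.call_args.kwargs["filter"])  # {"custom_key": {"$eq": "user123"}}
```
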
@@ -46,22 +46,31 @@ class ShortTermMemory(Memory):
        value: Any,
        metadata: Optional[Dict[str, Any]] = None,
        agent: Optional[str] = None,
        custom_key: Optional[str] = None,
    ) -> None:
        item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
        if self.memory_provider == "mem0":
            item.data = f"Remember the following insights from Agent run: {item.data}"

        super().save(value=item.data, metadata=item.metadata, agent=item.agent)
        super().save(value=item.data, metadata=item.metadata, agent=item.agent, custom_key=custom_key)

    def search(
        self,
        query: str,
        limit: int = 3,
        score_threshold: float = 0.35,
        custom_key: Optional[str] = None,
    ):
        filter_dict = None
        if custom_key:
            filter_dict = {"custom_key": {"$eq": custom_key}}

        return self.storage.search(
            query=query, limit=limit, score_threshold=score_threshold
        )  # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
            query=query,
            limit=limit,
            score_threshold=score_threshold,
            filter=filter_dict
        )

    def reset(self) -> None:
        try:
@@ -70,22 +70,31 @@ class LTMSQLiteStorage:
        )

    def load(
        self, task_description: str, latest_n: int
        self, task_description: str, latest_n: int, custom_key: Optional[str] = None
    ) -> Optional[List[Dict[str, Any]]]:
        """Queries the LTM table by task description with error handling."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    f"""

                query = """
                    SELECT metadata, datetime, score
                    FROM long_term_memories
                    WHERE task_description = ?
                """

                params = [task_description]

                if custom_key:
                    query += " AND json_extract(metadata, '$.custom_key') = ?"
                    params.append(custom_key)

                query += f"""
                    ORDER BY datetime DESC, score ASC
                    LIMIT {latest_n}
                    """,  # nosec
                    (task_description,),
                )
                """

                cursor.execute(query, params)
                rows = cursor.fetchall()
                if rows:
                    return [
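
A standalone sketch of the query-building pattern used above, run against an in-memory SQLite database; the table layout mirrors the columns referenced in the hunk and the row values are made up:

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE long_term_memories (task_description TEXT, metadata TEXT, datetime TEXT, score REAL)"
)
conn.execute(
    "INSERT INTO long_term_memories VALUES (?, ?, ?, ?)",
    ("demo task", json.dumps({"custom_key": "user123", "quality": 8}), "2025-01-01", 8.0),
)

task_description, custom_key, latest_n = "demo task", "user123", 3

query = """
    SELECT metadata, datetime, score
    FROM long_term_memories
    WHERE task_description = ?
"""
params = [task_description]
if custom_key:
    # json_extract pulls the key out of the JSON metadata column.
    query += " AND json_extract(metadata, '$.custom_key') = ?"
    params.append(custom_key)
query += f" ORDER BY datetime DESC, score ASC LIMIT {latest_n}"

print(conn.execute(query, params).fetchall())
```
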
@@ -120,7 +120,11 @@ class RAGStorage(BaseRAGStorage):

        try:
            with suppress_logging():
                response = self.collection.query(query_texts=query, n_results=limit)
                response = self.collection.query(
                    query_texts=query,
                    n_results=limit,
                    where=filter
                )

            results = []
            for i in range(len(response["ids"][0])):
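
The `where=filter` keyword matches ChromaDB's metadata-filter syntax, which is consistent with the `{"custom_key": {"$eq": ...}}` dicts built earlier in this diff. A hedged, self-contained sketch of that filter in isolation (collection name, documents, and ids are made up):

```python
import chromadb

client = chromadb.Client()
collection = client.get_or_create_collection("demo_memories")
collection.add(
    documents=["User prefers concise answers", "Team uses Python 3.12"],
    metadatas=[{"custom_key": "user123"}, {"custom_key": "team42"}],
    ids=["m1", "m2"],
)

# Only documents whose metadata matches the custom_key are candidates.
response = collection.query(
    query_texts=["What does the user prefer?"],
    n_results=1,
    where={"custom_key": {"$eq": "user123"}},
)
print(response["documents"])
```
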
@@ -26,20 +26,27 @@ class UserMemory(Memory):
        value,
        metadata: Optional[Dict[str, Any]] = None,
        agent: Optional[str] = None,
        custom_key: Optional[str] = None,
    ) -> None:
        # TODO: Change this function since we want to take care of the case where we save memories for the usr
        data = f"Remember the details about the user: {value}"
        super().save(data, metadata)
        super().save(data, metadata, custom_key=custom_key)

    def search(
        self,
        query: str,
        limit: int = 3,
        score_threshold: float = 0.35,
        custom_key: Optional[str] = None,
    ):
        filter_dict = None
        if custom_key:
            filter_dict = {"custom_key": {"$eq": custom_key}}

        results = self.storage.search(
            query=query,
            limit=limit,
            score_threshold=score_threshold,
            filter=filter_dict,
        )
        return results
@@ -1,10 +1,6 @@
from collections import defaultdict

from typing import Union

from pydantic import BaseModel, Field

from crewai.llm import LLM
from rich.box import HEAVY_EDGE
from rich.console import Console
from rich.table import Table
@@ -27,7 +23,7 @@ class CrewEvaluator:

    Attributes:
        crew (Crew): The crew of agents to evaluate.
        llm (LLM): The language model to use for evaluating the performance of the agents.
        openai_model_name (str): The model to use for evaluating the performance of the agents (for now ONLY OpenAI accepted).
        tasks_scores (defaultdict): A dictionary to store the scores of the agents for each task.
        iteration (int): The current iteration of the evaluation.
    """
@@ -36,20 +32,12 @@ class CrewEvaluator:
    run_execution_times: defaultdict = defaultdict(list)
    iteration: int = 0

    def __init__(self, crew, llm: Union[str, LLM]):
    def __init__(self, crew, openai_model_name: str):
        self.crew = crew
        try:
            self.llm = llm if isinstance(llm, LLM) else LLM(model=llm)
            if not hasattr(self.llm, 'model'):
                raise ValueError("Provided LLM instance must have a 'model' attribute")
        except Exception as e:
            raise ValueError(f"Failed to initialize LLM: {str(e)}")
        self.openai_model_name = openai_model_name
        self._telemetry = Telemetry()
        self._setup_for_evaluating()

    def __str__(self) -> str:
        return f"CrewEvaluator(model={str(self.llm)}, iteration={self.iteration})"

    def _setup_for_evaluating(self) -> None:
        """Sets up the crew for evaluating."""
        for task in self.crew.tasks:
@@ -63,7 +51,7 @@ class CrewEvaluator:
            ),
            backstory="Evaluator agent for crew evaluation with precise capabilities to evaluate the performance of the agents in the crew based on the tasks they have performed",
            verbose=False,
            llm=self.llm,
            llm=self.openai_model_name,
        )

    def _evaluation_task(
@@ -193,7 +181,7 @@ class CrewEvaluator:
            self.crew,
            evaluation_result.pydantic.quality,
            current_task._execution_time,
            str(self.llm),
            self.openai_model_name,
        )
        self.tasks_scores[self.iteration].append(evaluation_result.pydantic.quality)
        self.run_execution_times[self.iteration].append(
@@ -13,7 +13,6 @@ import pytest
from crewai.agent import Agent
from crewai.agents.cache import CacheHandler
from crewai.crew import Crew
from crewai.llm import LLM
from crewai.crews.crew_output import CrewOutput
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.process import Process
@@ -1124,7 +1123,7 @@ def test_kickoff_for_each_empty_input():
    assert results == []


@pytest.mark.vcr(filter_headeruvs=["authorization"])
@pytest.mark.vcr(filter_headers=["authorization"])
def test_kickoff_for_each_invalid_input():
    """Tests if kickoff_for_each raises TypeError for invalid input types."""
@@ -2813,10 +2812,10 @@ def test_conditional_should_execute():
@mock.patch("crewai.crew.CrewEvaluator")
@mock.patch("crewai.crew.Crew.copy")
@mock.patch("crewai.crew.Crew.kickoff")
def test_crew_testing_with_custom_llm(kickoff_mock, copy_mock, crew_evaluator):
def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
    task = Task(
        description="Test task",
        expected_output="Test output",
        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
        expected_output="5 bullet points with a paragraph for each idea.",
        agent=researcher,
    )
@@ -2828,107 +2827,23 @@ def test_crew_testing_with_custom_llm(kickoff_mock, copy_mock, crew_evaluator):
    # Create a mock for the copied crew
    copy_mock.return_value = crew

    custom_llm = LLM(model="gpt-4o-mini")
    n_iterations = 2
    crew.test(n_iterations, llm=custom_llm)
    crew.test(n_iterations, openai_model_name="gpt-4o-mini", inputs={"topic": "AI"})

    # Ensure kickoff is called on the copied crew
    kickoff_mock.assert_has_calls([mock.call(inputs=None), mock.call(inputs=None)])

    # Verify CrewEvaluator was called with custom LLM
    crew_evaluator.assert_has_calls([
        mock.call(crew, custom_llm),
        mock.call().set_iteration(1),
        mock.call().set_iteration(2),
        mock.call().print_crew_evaluation_result(),
    ])

@mock.patch("crewai.crew.CrewEvaluator")
@mock.patch("crewai.crew.Crew.copy")
@mock.patch("crewai.crew.Crew.kickoff")
def test_crew_testing_backward_compatibility(kickoff_mock, copy_mock, crew_evaluator):
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=researcher,
    kickoff_mock.assert_has_calls(
        [mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
    )

    crew = Crew(
        agents=[researcher],
        tasks=[task],
    crew_evaluator.assert_has_calls(
        [
            mock.call(crew, "gpt-4o-mini"),
            mock.call().set_iteration(1),
            mock.call().set_iteration(2),
            mock.call().print_crew_evaluation_result(),
        ]
    )

    # Create a mock for the copied crew
    copy_mock.return_value = crew

    n_iterations = 2
    with pytest.warns(DeprecationWarning, match="openai_model_name is deprecated"):
        crew.test(n_iterations, openai_model_name="gpt-4o-mini", inputs={"topic": "AI"})

    # Ensure kickoff is called on the copied crew
    kickoff_mock.assert_has_calls([
        mock.call(inputs={"topic": "AI"}),
        mock.call(inputs={"topic": "AI"})
    ])

    # Verify CrewEvaluator was called with string model name
    crew_evaluator.assert_has_calls([
        mock.call(crew, mock.ANY),
        mock.call().set_iteration(1),
        mock.call().set_iteration(2),
        mock.call().print_crew_evaluation_result(),
    ])

@mock.patch("crewai.crew.CrewEvaluator")
@mock.patch("crewai.crew.Crew.copy")
@mock.patch("crewai.crew.Crew.kickoff")
def test_crew_testing_missing_llm(kickoff_mock, copy_mock, crew_evaluator):
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=researcher,
    )

    crew = Crew(
        agents=[researcher],
        tasks=[task],
    )

    # Create a mock for the copied crew
    copy_mock.return_value = crew

    n_iterations = 2
    with pytest.raises(ValueError, match="Either llm or openai_model_name must be provided"):
        crew.test(n_iterations)

@mock.patch("crewai.crew.CrewEvaluator")
@mock.patch("crewai.crew.Crew.copy")
@mock.patch("crewai.crew.Crew.kickoff")
def test_crew_testing_with_invalid_llm(kickoff_mock, copy_mock, crew_evaluator):
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=researcher,
    )

    crew = Crew(
        agents=[researcher],
        tasks=[task],
    )

    # Create a mock for the copied crew
    copy_mock.return_value = crew

    # Test invalid LLM type
    with pytest.raises(ValueError, match="Failed to initialize LLM"):
        crew.test(n_iterations=2, llm={})

    # Test LLM without model attribute
    class InvalidLLM:
        def __init__(self): pass
    with pytest.raises(ValueError, match="LLM must be either a string model name or an LLM instance"):
        crew.test(n_iterations=2, llm=InvalidLLM())


@pytest.mark.vcr(filter_headers=["authorization"])
def test_hierarchical_verbose_manager_agent():
@@ -3210,4 +3125,4 @@ def test_multimodal_agent_live_image_analysis():
    # Verify we got a meaningful response
    assert isinstance(result.raw, str)
    assert len(result.raw) > 100  # Expecting a detailed analysis
    assert "error" not in result.raw.lower()  # No error messages in response
    assert "error" not in result.raw.lower()  # No error messages in response
tests/memory/custom_key_memory_test.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import pytest
from unittest.mock import patch, MagicMock

from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task


@pytest.fixture
def short_term_memory():
    """Fixture to create a ShortTermMemory instance"""
    agent = Agent(
        role="Researcher",
        goal="Search relevant data and provide results",
        backstory="You are a researcher at a leading tech think tank.",
        tools=[],
        verbose=True,
    )

    task = Task(
        description="Perform a search on specific topics.",
        expected_output="A list of relevant URLs based on the search query.",
        agent=agent,
    )
    return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))


def test_save_with_custom_key(short_term_memory):
    """Test that save method correctly passes custom_key to storage"""
    with patch.object(short_term_memory.storage, 'save') as mock_save:
        short_term_memory.save(
            value="Test data",
            metadata={"task": "test_task"},
            agent="test_agent",
            custom_key="user123",
        )

        called_args = mock_save.call_args[0]
        called_kwargs = mock_save.call_args[1]

        assert "custom_key" in called_args[1]
        assert called_args[1]["custom_key"] == "user123"


def test_search_with_custom_key(short_term_memory):
    """Test that search method correctly passes custom_key to storage"""
    expected_results = [{"context": "Test data", "metadata": {"custom_key": "user123"}, "score": 0.95}]

    with patch.object(short_term_memory.storage, 'search', return_value=expected_results) as mock_search:
        results = short_term_memory.search("test query", custom_key="user123")

        mock_search.assert_called_once()
        filter_arg = mock_search.call_args[1].get('filter')
        assert filter_arg == {"custom_key": {"$eq": "user123"}}
        assert results == expected_results
@@ -23,7 +23,7 @@ class TestCrewEvaluator:
        )
        crew = Crew(agents=[agent], tasks=[task])

        return CrewEvaluator(crew, llm="gpt-4o-mini")
        return CrewEvaluator(crew, openai_model_name="gpt-4o-mini")

    def test_setup_for_evaluating(self, crew_planner):
        crew_planner._setup_for_evaluating()