mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
refactor: unify rag storage with instance-specific client support (#3455)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
- ignore line length errors globally
- migrate knowledge/memory and crew query_knowledge to `SearchResult`
- remove legacy chromadb utils; fix empty metadata handling
- restore openai as default embedding provider; support instance-specific clients
- update and fix tests for `SearchResult` migration and rag changes
This commit is contained in:
@@ -9,19 +9,19 @@ import pytest
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.process import Process
|
||||
from crewai.tools import tool
|
||||
from crewai.tools.tool_calling import InstructorToolCalling
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
from crewai.utilities import RPMController
|
||||
from crewai.utilities.errors import AgentRepositoryError
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
|
||||
from crewai.process import Process
|
||||
|
||||
|
||||
def test_agent_llm_creation_with_env_vars():
|
||||
@@ -445,7 +445,7 @@ def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_powered_by_new_o_model_family_that_uses_tool():
|
||||
@tool
|
||||
def comapny_customer_data() -> float:
|
||||
def comapny_customer_data() -> str:
|
||||
"""Useful for getting customer related data."""
|
||||
return "The company has 42 customers"
|
||||
|
||||
@@ -559,9 +559,9 @@ def test_agent_repeated_tool_usage(capsys):
|
||||
expected_message = (
|
||||
"I tried reusing the same input, I must stop using this action input."
|
||||
)
|
||||
assert (
|
||||
expected_message in output
|
||||
), f"Expected message not found in output. Output was: {output}"
|
||||
assert expected_message in output, (
|
||||
f"Expected message not found in output. Output was: {output}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -602,9 +602,9 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
|
||||
has_max_iterations = "maximum iterations reached" in output_lower
|
||||
has_final_answer = "final answer" in output_lower or "42" in captured.out
|
||||
|
||||
assert (
|
||||
has_repeated_usage_message or (has_max_iterations and has_final_answer)
|
||||
), f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
|
||||
assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
|
||||
f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -880,7 +880,7 @@ def test_agent_step_callback():
|
||||
with patch.object(StepCallback, "callback") as callback:
|
||||
|
||||
@tool
|
||||
def learn_about_AI() -> str:
|
||||
def learn_about_ai() -> str:
|
||||
"""Useful for when you need to learn about AI to write an paragraph about it."""
|
||||
return "AI is a very broad field."
|
||||
|
||||
@@ -888,7 +888,7 @@ def test_agent_step_callback():
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
tools=[learn_about_AI],
|
||||
tools=[learn_about_ai],
|
||||
step_callback=StepCallback().callback,
|
||||
)
|
||||
|
||||
@@ -910,7 +910,7 @@ def test_agent_function_calling_llm():
|
||||
llm = "gpt-4o"
|
||||
|
||||
@tool
|
||||
def learn_about_AI() -> str:
|
||||
def learn_about_ai() -> str:
|
||||
"""Useful for when you need to learn about AI to write an paragraph about it."""
|
||||
return "AI is a very broad field."
|
||||
|
||||
@@ -918,7 +918,7 @@ def test_agent_function_calling_llm():
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
tools=[learn_about_AI],
|
||||
tools=[learn_about_ai],
|
||||
llm="gpt-4o",
|
||||
max_iter=2,
|
||||
function_calling_llm=llm,
|
||||
@@ -1356,7 +1356,7 @@ def test_agent_training_handler(crew_training_handler):
|
||||
verbose=True,
|
||||
)
|
||||
crew_training_handler().load.return_value = {
|
||||
f"{str(agent.id)}": {"0": {"human_feedback": "good"}}
|
||||
f"{agent.id!s}": {"0": {"human_feedback": "good"}}
|
||||
}
|
||||
|
||||
result = agent._training_handler(task_prompt=task_prompt)
|
||||
@@ -1473,7 +1473,7 @@ def test_agent_with_custom_stop_words():
|
||||
)
|
||||
|
||||
assert isinstance(agent.llm, LLM)
|
||||
assert set(agent.llm.stop) == set(stop_words + ["\nObservation:"])
|
||||
assert set(agent.llm.stop) == set([*stop_words, "\nObservation:"])
|
||||
assert all(word in agent.llm.stop for word in stop_words)
|
||||
assert "\nObservation:" in agent.llm.stop
|
||||
|
||||
@@ -1530,7 +1530,7 @@ def test_llm_call_with_error():
|
||||
llm = LLM(model="non-existent-model")
|
||||
messages = [{"role": "user", "content": "This should fail"}]
|
||||
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
llm.call(messages)
|
||||
|
||||
|
||||
@@ -1830,11 +1830,11 @@ def test_agent_execute_task_with_ollama():
|
||||
def test_agent_with_knowledge_sources():
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
with patch("crewai.knowledge") as MockKnowledge:
|
||||
mock_knowledge_instance = MockKnowledge.return_value
|
||||
with patch("crewai.knowledge") as mock_knowledge:
|
||||
mock_knowledge_instance = mock_knowledge.return_value
|
||||
mock_knowledge_instance.sources = [string_source]
|
||||
mock_knowledge_instance.search.return_value = [{"content": content}]
|
||||
MockKnowledge.add_sources.return_value = [string_source]
|
||||
mock_knowledge.add_sources.return_value = [string_source]
|
||||
|
||||
agent = Agent(
|
||||
role="Information Agent",
|
||||
@@ -1863,12 +1863,25 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
knowledge_config = KnowledgeConfig(results_limit=10, score_threshold=0.5)
|
||||
with patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as MockKnowledge:
|
||||
mock_knowledge_instance = MockKnowledge.return_value
|
||||
mock_knowledge_instance.sources = [string_source]
|
||||
mock_knowledge_instance.query.return_value = [{"content": content}]
|
||||
with (
|
||||
patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as mock_knowledge_storage,
|
||||
patch(
|
||||
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
|
||||
) as mock_base_knowledge_storage,
|
||||
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
|
||||
):
|
||||
mock_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_storage_instance.sources = [string_source]
|
||||
mock_storage_instance.query.return_value = [{"content": content}]
|
||||
mock_storage_instance.save.return_value = None
|
||||
|
||||
mock_chromadb_instance = mock_chromadb.return_value
|
||||
mock_chromadb_instance.add_documents.return_value = None
|
||||
|
||||
mock_base_knowledge_storage.return_value = mock_storage_instance
|
||||
|
||||
with patch.object(Knowledge, "query") as mock_knowledge_query:
|
||||
agent = Agent(
|
||||
role="Information Agent",
|
||||
@@ -1898,15 +1911,27 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_defau
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
knowledge_config = KnowledgeConfig()
|
||||
with patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as MockKnowledge:
|
||||
mock_knowledge_instance = MockKnowledge.return_value
|
||||
mock_knowledge_instance.sources = [string_source]
|
||||
mock_knowledge_instance.query.return_value = [{"content": content}]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as mock_knowledge_storage,
|
||||
patch(
|
||||
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
|
||||
) as mock_base_knowledge_storage,
|
||||
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
|
||||
):
|
||||
mock_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_storage_instance.sources = [string_source]
|
||||
mock_storage_instance.query.return_value = [{"content": content}]
|
||||
mock_storage_instance.save.return_value = None
|
||||
|
||||
mock_chromadb_instance = mock_chromadb.return_value
|
||||
mock_chromadb_instance.add_documents.return_value = None
|
||||
|
||||
mock_base_knowledge_storage.return_value = mock_storage_instance
|
||||
|
||||
with patch.object(Knowledge, "query") as mock_knowledge_query:
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
knowledge_config = KnowledgeConfig()
|
||||
agent = Agent(
|
||||
role="Information Agent",
|
||||
goal="Provide information based on knowledge sources",
|
||||
@@ -1935,10 +1960,16 @@ def test_agent_with_knowledge_sources_extensive_role():
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
|
||||
with patch("crewai.knowledge") as MockKnowledge:
|
||||
mock_knowledge_instance = MockKnowledge.return_value
|
||||
with (
|
||||
patch("crewai.knowledge") as mock_knowledge,
|
||||
patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage.save"
|
||||
) as mock_save,
|
||||
):
|
||||
mock_knowledge_instance = mock_knowledge.return_value
|
||||
mock_knowledge_instance.sources = [string_source]
|
||||
mock_knowledge_instance.query.return_value = [{"content": content}]
|
||||
mock_save.return_value = None
|
||||
|
||||
agent = Agent(
|
||||
role="Information Agent with extensive role description that is longer than 80 characters",
|
||||
@@ -1968,8 +1999,8 @@ def test_agent_with_knowledge_sources_works_with_copy():
|
||||
with patch(
|
||||
"crewai.knowledge.source.base_knowledge_source.BaseKnowledgeSource",
|
||||
autospec=True,
|
||||
) as MockKnowledgeSource:
|
||||
mock_knowledge_source_instance = MockKnowledgeSource.return_value
|
||||
) as mock_knowledge_source:
|
||||
mock_knowledge_source_instance = mock_knowledge_source.return_value
|
||||
mock_knowledge_source_instance.__class__ = BaseKnowledgeSource
|
||||
mock_knowledge_source_instance.sources = [string_source]
|
||||
|
||||
@@ -1983,9 +2014,9 @@ def test_agent_with_knowledge_sources_works_with_copy():
|
||||
|
||||
with patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as MockKnowledgeStorage:
|
||||
mock_knowledge_storage = MockKnowledgeStorage.return_value
|
||||
agent.knowledge_storage = mock_knowledge_storage
|
||||
) as mock_knowledge_storage:
|
||||
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
|
||||
agent.knowledge_storage = mock_knowledge_storage_instance
|
||||
|
||||
agent_copy = agent.copy()
|
||||
|
||||
@@ -2004,11 +2035,30 @@ def test_agent_with_knowledge_sources_generate_search_query():
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
|
||||
with patch("crewai.knowledge") as MockKnowledge:
|
||||
mock_knowledge_instance = MockKnowledge.return_value
|
||||
with (
|
||||
patch("crewai.knowledge") as mock_knowledge,
|
||||
patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as mock_knowledge_storage,
|
||||
patch(
|
||||
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
|
||||
) as mock_base_knowledge_storage,
|
||||
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
|
||||
):
|
||||
mock_knowledge_instance = mock_knowledge.return_value
|
||||
mock_knowledge_instance.sources = [string_source]
|
||||
mock_knowledge_instance.query.return_value = [{"content": content}]
|
||||
|
||||
mock_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_storage_instance.sources = [string_source]
|
||||
mock_storage_instance.query.return_value = [{"content": content}]
|
||||
mock_storage_instance.save.return_value = None
|
||||
|
||||
mock_chromadb_instance = mock_chromadb.return_value
|
||||
mock_chromadb_instance.add_documents.return_value = None
|
||||
|
||||
mock_base_knowledge_storage.return_value = mock_storage_instance
|
||||
|
||||
agent = Agent(
|
||||
role="Information Agent with extensive role description that is longer than 80 characters",
|
||||
goal="Provide information based on knowledge sources",
|
||||
@@ -2270,7 +2320,26 @@ def test_get_knowledge_search_query():
|
||||
i18n = I18N()
|
||||
task_prompt = task.prompt()
|
||||
|
||||
with patch.object(agent, "_get_knowledge_search_query") as mock_get_query:
|
||||
with (
|
||||
patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as mock_knowledge_storage,
|
||||
patch(
|
||||
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
|
||||
) as mock_base_knowledge_storage,
|
||||
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
|
||||
patch.object(agent, "_get_knowledge_search_query") as mock_get_query,
|
||||
):
|
||||
mock_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_storage_instance.sources = [string_source]
|
||||
mock_storage_instance.query.return_value = [{"content": content}]
|
||||
mock_storage_instance.save.return_value = None
|
||||
|
||||
mock_chromadb_instance = mock_chromadb.return_value
|
||||
mock_chromadb_instance.add_documents.return_value = None
|
||||
|
||||
mock_base_knowledge_storage.return_value = mock_storage_instance
|
||||
|
||||
mock_get_query.return_value = "Capital of France"
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
@@ -2312,9 +2381,9 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
# Mock embedchain initialization to prevent race conditions in parallel CI execution
|
||||
with patch("embedchain.client.Client.setup"):
|
||||
from crewai_tools import (
|
||||
SerperDevTool,
|
||||
FileReadTool,
|
||||
EnterpriseActionTool,
|
||||
FileReadTool,
|
||||
SerperDevTool,
|
||||
)
|
||||
|
||||
mock_get_response = MagicMock()
|
||||
@@ -2347,7 +2416,7 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
tool_action = EnterpriseActionTool(
|
||||
name="test_name",
|
||||
description="test_description",
|
||||
enterprise_action_token="test_token",
|
||||
enterprise_action_token="test_token", # noqa: S106
|
||||
action_name="test_action_name",
|
||||
action_schema={"test": "test"},
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Test Knowledge creation and querying functionality."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -23,7 +22,7 @@ def mock_vector_db():
|
||||
instance = mock.return_value
|
||||
instance.query.return_value = [
|
||||
{
|
||||
"context": "Brandon's favorite color is blue and he likes Mexican food.",
|
||||
"content": "Brandon's favorite color is blue and he likes Mexican food.",
|
||||
"score": 0.9,
|
||||
}
|
||||
]
|
||||
@@ -44,13 +43,13 @@ def test_single_short_string(mock_vector_db):
|
||||
content=content, metadata={"preference": "personal"}
|
||||
)
|
||||
mock_vector_db.sources = [string_source]
|
||||
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
|
||||
# Perform a query
|
||||
query = "What is Brandon's favorite color?"
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the results contain the expected information
|
||||
assert any("blue" in result["context"].lower() for result in results)
|
||||
assert any("blue" in result["content"].lower() for result in results)
|
||||
# Verify the mock was called
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
@@ -84,14 +83,14 @@ def test_single_2k_character_string(mock_vector_db):
|
||||
content=content, metadata={"preference": "personal"}
|
||||
)
|
||||
mock_vector_db.sources = [string_source]
|
||||
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
|
||||
|
||||
# Perform a query
|
||||
query = "What is Brandon's favorite movie?"
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the results contain the expected information
|
||||
assert any("inception" in result["context"].lower() for result in results)
|
||||
assert any("inception" in result["content"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -109,7 +108,7 @@ def test_multiple_short_strings(mock_vector_db):
|
||||
|
||||
# Mock the vector db query response
|
||||
mock_vector_db.query.return_value = [
|
||||
{"context": "Brandon has a dog named Max.", "score": 0.9}
|
||||
{"content": "Brandon has a dog named Max.", "score": 0.9}
|
||||
]
|
||||
|
||||
mock_vector_db.sources = string_sources
|
||||
@@ -119,7 +118,7 @@ def test_multiple_short_strings(mock_vector_db):
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("max" in result["context"].lower() for result in results)
|
||||
assert any("max" in result["content"].lower() for result in results)
|
||||
# Verify the mock was called
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
@@ -180,7 +179,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
|
||||
]
|
||||
|
||||
mock_vector_db.sources = string_sources
|
||||
mock_vector_db.query.return_value = [{"context": contents[1], "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"content": contents[1], "score": 0.9}]
|
||||
|
||||
# Perform a query
|
||||
query = "What is Brandon's favorite book?"
|
||||
@@ -188,7 +187,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any(
|
||||
"the hitchhiker's guide to the galaxy" in result["context"].lower()
|
||||
"the hitchhiker's guide to the galaxy" in result["content"].lower()
|
||||
for result in results
|
||||
)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
@@ -205,13 +204,13 @@ def test_single_short_file(mock_vector_db, tmpdir):
|
||||
file_paths=[file_path], metadata={"preference": "personal"}
|
||||
)
|
||||
mock_vector_db.sources = [file_source]
|
||||
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
|
||||
# Perform a query
|
||||
query = "What sport does Brandon like?"
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the results contain the expected information
|
||||
assert any("basketball" in result["context"].lower() for result in results)
|
||||
assert any("basketball" in result["content"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -247,13 +246,13 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
|
||||
file_paths=[file_path], metadata={"preference": "personal"}
|
||||
)
|
||||
mock_vector_db.sources = [file_source]
|
||||
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
|
||||
# Perform a query
|
||||
query = "What is Brandon's favorite movie?"
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the results contain the expected information
|
||||
assert any("inception" in result["context"].lower() for result in results)
|
||||
assert any("inception" in result["content"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -286,13 +285,13 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
|
||||
]
|
||||
mock_vector_db.sources = file_sources
|
||||
mock_vector_db.query.return_value = [
|
||||
{"context": "Brandon lives in New York.", "score": 0.9}
|
||||
{"content": "Brandon lives in New York.", "score": 0.9}
|
||||
]
|
||||
# Perform a query
|
||||
query = "What city does he reside in?"
|
||||
results = mock_vector_db.query(query)
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("new york" in result["context"].lower() for result in results)
|
||||
assert any("new york" in result["content"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -360,7 +359,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
|
||||
mock_vector_db.sources = file_sources
|
||||
mock_vector_db.query.return_value = [
|
||||
{
|
||||
"context": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
|
||||
"content": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
|
||||
"score": 0.9,
|
||||
}
|
||||
]
|
||||
@@ -370,7 +369,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any(
|
||||
"the hitchhiker's guide to the galaxy" in result["context"].lower()
|
||||
"the hitchhiker's guide to the galaxy" in result["content"].lower()
|
||||
for result in results
|
||||
)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
@@ -407,14 +406,14 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
|
||||
|
||||
# Combine string and file sources
|
||||
mock_vector_db.sources = string_sources + file_sources
|
||||
mock_vector_db.query.return_value = [{"context": file_contents[1], "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"content": file_contents[1], "score": 0.9}]
|
||||
|
||||
# Perform a query
|
||||
query = "What is Brandon's favorite book?"
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("the alchemist" in result["context"].lower() for result in results)
|
||||
assert any("the alchemist" in result["content"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -430,7 +429,7 @@ def test_pdf_knowledge_source(mock_vector_db):
|
||||
)
|
||||
mock_vector_db.sources = [pdf_source]
|
||||
mock_vector_db.query.return_value = [
|
||||
{"context": "crewai create crew latest-ai-development", "score": 0.9}
|
||||
{"content": "crewai create crew latest-ai-development", "score": 0.9}
|
||||
]
|
||||
|
||||
# Perform a query
|
||||
@@ -439,7 +438,7 @@ def test_pdf_knowledge_source(mock_vector_db):
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any(
|
||||
"crewai create crew latest-ai-development" in result["context"].lower()
|
||||
"crewai create crew latest-ai-development" in result["content"].lower()
|
||||
for result in results
|
||||
)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
@@ -467,7 +466,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
|
||||
)
|
||||
mock_vector_db.sources = [csv_source]
|
||||
mock_vector_db.query.return_value = [
|
||||
{"context": "Brandon is 30 years old.", "score": 0.9}
|
||||
{"content": "Brandon is 30 years old.", "score": 0.9}
|
||||
]
|
||||
|
||||
# Perform a query
|
||||
@@ -475,7 +474,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("30" in result["context"] for result in results)
|
||||
assert any("30" in result["content"] for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -502,7 +501,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
|
||||
)
|
||||
mock_vector_db.sources = [json_source]
|
||||
mock_vector_db.query.return_value = [
|
||||
{"context": "Alice lives in Los Angeles.", "score": 0.9}
|
||||
{"content": "Alice lives in Los Angeles.", "score": 0.9}
|
||||
]
|
||||
|
||||
# Perform a query
|
||||
@@ -510,7 +509,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("los angeles" in result["context"].lower() for result in results)
|
||||
assert any("los angeles" in result["content"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -518,7 +517,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
|
||||
"""Test ExcelKnowledgeSource with a simple Excel file."""
|
||||
|
||||
# Create an Excel file with sample data
|
||||
import pandas as pd
|
||||
import pandas as pd # type: ignore[import-untyped]
|
||||
|
||||
excel_data = {
|
||||
"Name": ["Brandon", "Alice", "Bob"],
|
||||
@@ -535,7 +534,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
|
||||
)
|
||||
mock_vector_db.sources = [excel_source]
|
||||
mock_vector_db.query.return_value = [
|
||||
{"context": "Brandon is 30 years old.", "score": 0.9}
|
||||
{"content": "Brandon is 30 years old.", "score": 0.9}
|
||||
]
|
||||
|
||||
# Perform a query
|
||||
@@ -543,7 +542,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("30" in result["context"] for result in results)
|
||||
assert any("30" in result["content"] for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -557,20 +556,20 @@ def test_docling_source(mock_vector_db):
|
||||
mock_vector_db.sources = [docling_source]
|
||||
mock_vector_db.query.return_value = [
|
||||
{
|
||||
"context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
|
||||
"content": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
|
||||
"score": 0.9,
|
||||
}
|
||||
]
|
||||
# Perform a query
|
||||
query = "What is reward hacking?"
|
||||
results = mock_vector_db.query(query)
|
||||
assert any("reward hacking" in result["context"].lower() for result in results)
|
||||
assert any("reward hacking" in result["content"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.vcr
|
||||
def test_multiple_docling_sources():
|
||||
urls: List[Union[Path, str]] = [
|
||||
def test_multiple_docling_sources() -> None:
|
||||
urls: list[Path | str] = [
|
||||
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
|
||||
"https://lilianweng.github.io/posts/2024-07-07-hallucination/",
|
||||
]
|
||||
|
||||
191
tests/knowledge/test_knowledge_searchresult.py
Normal file
191
tests/knowledge/test_knowledge_searchresult.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Tests for Knowledge SearchResult type conversion and integration."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.knowledge.knowledge import Knowledge # type: ignore[import-untyped]
|
||||
from crewai.knowledge.source.string_knowledge_source import ( # type: ignore[import-untyped]
|
||||
StringKnowledgeSource,
|
||||
)
|
||||
from crewai.knowledge.utils.knowledge_utils import ( # type: ignore[import-untyped]
|
||||
extract_knowledge_context,
|
||||
)
|
||||
|
||||
|
||||
def test_knowledge_query_returns_searchresult() -> None:
    """Check that Knowledge.query returns the storage hits as SearchResult-style dicts."""
    stored_hits = [
        {
            "content": "AI is fascinating",
            "score": 0.9,
            "metadata": {"source": "doc1"},
        },
        {
            "content": "Machine learning rocks",
            "score": 0.8,
            "metadata": {"source": "doc2"},
        },
    ]

    with patch("crewai.knowledge.knowledge.KnowledgeStorage") as storage_cls:
        # Replace the storage backend with a mock whose search yields canned hits.
        storage = MagicMock()
        storage.search.return_value = stored_hits
        storage_cls.return_value = storage

        knowledge = Knowledge(
            collection_name="test_collection",
            sources=[StringKnowledgeSource(content="Test knowledge content")],
        )

        hits = knowledge.query(["AI technology"], results_limit=5, score_threshold=0.3)

        # The query is expected to reach storage.search with results_limit passed as limit.
        storage.search.assert_called_once_with(
            ["AI technology"], limit=5, score_threshold=0.3
        )

        assert isinstance(hits, list)
        assert len(hits) == 2

        # Every hit must be a dict carrying the three SearchResult keys.
        for hit in hits:
            assert isinstance(hit, dict)
            for key in ("content", "score", "metadata"):
                assert key in hit

        first, second = hits
        assert first["content"] == "AI is fascinating"
        assert first["score"] == 0.9
        assert second["content"] == "Machine learning rocks"
        assert second["score"] == 0.8
|
||||
|
||||
|
||||
def test_knowledge_query_with_empty_results() -> None:
    """Check that Knowledge.query returns an empty list when storage finds nothing."""
    with patch("crewai.knowledge.knowledge.KnowledgeStorage") as storage_cls:
        # Mocked storage backend whose search comes back with no hits.
        storage = MagicMock()
        storage.search.return_value = []
        storage_cls.return_value = storage

        knowledge = Knowledge(
            collection_name="empty_test",
            sources=[StringKnowledgeSource(content="Test content")],
        )

        hits = knowledge.query(["nonexistent query"])

        assert isinstance(hits, list)
        assert not hits
|
||||
|
||||
|
||||
def test_extract_knowledge_context_with_searchresult() -> None:
    """Check that extract_knowledge_context renders SearchResult-style dicts.

    Each entry's ``content`` must appear in the returned context, and the three
    contents must also appear newline-joined in input order, alongside the
    "Additional Information:" header.
    """
    scored_snippets = (
        ("Python is great for AI", 0.95),
        ("Machine learning algorithms", 0.88),
        ("Deep learning frameworks", 0.82),
    )
    search_results = [
        {"content": text, "score": score, "metadata": {}}
        for text, score in scored_snippets
    ]

    context = extract_knowledge_context(search_results)

    assert "Additional Information:" in context
    for text, _score in scored_snippets:
        assert text in context

    # The snippets must show up contiguously, one per line, in input order.
    joined_snippets = (
        "Python is great for AI\nMachine learning algorithms\nDeep learning frameworks"
    )
    assert joined_snippets in context
|
||||
|
||||
|
||||
def test_extract_knowledge_context_with_empty_content() -> None:
    """extract_knowledge_context returns "" when no result carries usable content."""
    blank_content = {"content": "", "score": 0.5, "metadata": {}}
    null_content = {"content": None, "score": 0.4, "metadata": {}}
    missing_content_key = {"score": 0.3, "metadata": {}}

    context = extract_knowledge_context(
        [blank_content, null_content, missing_content_key]
    )

    assert context == ""
|
||||
|
||||
|
||||
def test_extract_knowledge_context_filters_invalid_results() -> None:
    """Only entries with non-empty content survive in the extracted context.

    The input mixes two valid entries with an empty string, a ``None`` entry
    and a ``None`` content value.
    """
    valid_first = {"content": "Valid content 1", "score": 0.9, "metadata": {}}
    valid_second = {"content": "Valid content 2", "score": 0.7, "metadata": {}}
    search_results: list[dict[str, Any] | None] = [
        valid_first,
        {"content": "", "score": 0.8, "metadata": {}},
        valid_second,
        None,
        {"content": None, "score": 0.6, "metadata": {}},
    ]

    context = extract_knowledge_context(search_results)

    for expected_fragment in (
        "Additional Information:",
        "Valid content 1",
        "Valid content 2",
    ):
        assert expected_fragment in context

    # Exactly one line break in the whole context: the filtered-out entries
    # must not leave blank lines behind.
    assert context.count("\n") == 1
|
||||
|
||||
|
||||
@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_storage_exception_handling(
    mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
    """Test Knowledge.query raises ValueError once its storage is unset.

    The storage mock is configured to fail on ``search``, but the storage is
    removed before querying, so the error under test is the
    "Storage is not initialized" ValueError rather than the mocked one.
    """
    mock_storage = MagicMock()
    mock_storage_class.return_value = mock_storage
    mock_storage.search.side_effect = Exception("Storage error")

    sources = [StringKnowledgeSource(content="Test content")]
    knowledge = Knowledge(collection_name="error_test", sources=sources)

    # Drop the storage *before* entering pytest.raises so that only query()
    # itself can satisfy the expectation (mirrors the reset-without-storage test).
    knowledge.storage = None

    with pytest.raises(ValueError, match="Storage is not initialized"):
        knowledge.query(["test query"])
|
||||
|
||||
|
||||
def test_knowledge_add_sources_integration() -> None:
    """After Knowledge.add_sources, every source points at the Knowledge storage."""
    storage_double = MagicMock()

    with patch(
        "crewai.knowledge.knowledge.KnowledgeStorage", return_value=storage_double
    ):
        sources = [
            StringKnowledgeSource(content=text) for text in ("Content 1", "Content 2")
        ]
        knowledge = Knowledge(collection_name="add_sources_test", sources=sources)

        knowledge.add_sources()

        for source in sources:
            assert source.storage == storage_double
|
||||
|
||||
|
||||
def test_knowledge_reset_integration() -> None:
    """Knowledge.reset forwards exactly one reset call to its storage."""
    storage_double = MagicMock()

    with patch(
        "crewai.knowledge.knowledge.KnowledgeStorage", return_value=storage_double
    ):
        knowledge = Knowledge(
            collection_name="reset_test",
            sources=[StringKnowledgeSource(content="Test content")],
        )

        knowledge.reset()

        storage_double.reset.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_reset_without_storage(
    mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
    """Knowledge.reset must raise ValueError when storage has been set to None."""
    knowledge = Knowledge(
        collection_name="no_storage_test",
        sources=[StringKnowledgeSource(content="Test content")],
    )
    knowledge.storage = None

    with pytest.raises(ValueError, match="Storage is not initialized"):
        knowledge.reset()
|
||||
196
tests/knowledge/test_knowledge_storage_integration.py
Normal file
196
tests/knowledge/test_knowledge_storage_integration.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Integration tests for KnowledgeStorage RAG client migration."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
|
||||
KnowledgeStorage,
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.create_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_knowledge_storage_uses_rag_client(
    mock_get_embedding: MagicMock,
    mock_create_client: MagicMock,
    mock_get_client: MagicMock,
) -> None:
    """KnowledgeStorage with an embedder searches through its own created client.

    With an embedder config supplied, the storage is expected to call
    ``create_client`` once, never fall back to ``get_rag_client``, and forward
    the search to that client with the ``knowledge_``-prefixed collection name.
    """
    rag_client = MagicMock()
    rag_client.search.return_value = [
        {"content": "test content", "score": 0.9, "metadata": {"source": "test"}}
    ]
    mock_create_client.return_value = rag_client
    mock_get_client.return_value = rag_client

    storage = KnowledgeStorage(
        embedder={"provider": "openai", "model": "text-embedding-3-small"},
        collection_name="test_knowledge",
    )
    mock_create_client.assert_called_once()

    results = storage.search(["test query"], limit=5, score_threshold=0.3)

    # The instance-specific client must be used instead of the shared one.
    mock_get_client.assert_not_called()
    expected_search_kwargs = {
        "collection_name": "knowledge_test_knowledge",
        "query": "test query",
        "limit": 5,
        "metadata_filter": None,
        "score_threshold": 0.3,
    }
    rag_client.search.assert_called_once_with(**expected_search_kwargs)

    assert isinstance(results, list)
    assert len(results) == 1
    first_hit = results[0]
    assert isinstance(first_hit, dict)
    assert "content" in first_hit
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_collection_name_prefixing(mock_get_client: MagicMock) -> None:
    """Named collections get a ``knowledge_`` prefix; the default is ``knowledge``."""
    rag_client = MagicMock()
    rag_client.search.return_value = []
    mock_get_client.return_value = rag_client

    named_storage = KnowledgeStorage(collection_name="custom_knowledge")
    named_storage.search(["test"], limit=1)

    rag_client.search.assert_called_once()
    named_kwargs = rag_client.search.call_args.kwargs
    assert named_kwargs["collection_name"] == "knowledge_custom_knowledge"

    # reset_mock() clears the recorded calls but keeps the configured
    # return value, so the second search still gets an empty list.
    rag_client.reset_mock()
    default_storage = KnowledgeStorage()
    default_storage.search(["test"], limit=1)

    default_kwargs = rag_client.search.call_args.kwargs
    assert default_kwargs["collection_name"] == "knowledge"
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_save_documents_integration(mock_get_client: MagicMock) -> None:
    """KnowledgeStorage.save creates the collection and adds each document to it."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client

    storage = KnowledgeStorage(collection_name="test_docs")
    storage.save(["Document 1 content", "Document 2 content"])

    rag_client.get_or_create_collection.assert_called_once_with(
        collection_name="knowledge_test_docs"
    )
    rag_client.add_documents.assert_called_once()

    # Inspect the records handed to the client: one per input string, in order.
    added_docs = rag_client.add_documents.call_args.kwargs["documents"]
    assert len(added_docs) == 2
    first_doc, second_doc = added_docs[0], added_docs[1]
    assert first_doc["content"] == "Document 1 content"
    assert second_doc["content"] == "Document 2 content"
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_reset_integration(mock_get_client: MagicMock) -> None:
    """KnowledgeStorage.reset deletes the prefixed collection via the RAG client."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client

    KnowledgeStorage(collection_name="test_reset").reset()

    rag_client.delete_collection.assert_called_once_with(
        collection_name="knowledge_test_reset"
    )
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_search_error_handling(mock_get_client: MagicMock) -> None:
    """A failing RAG client search results in an empty list, not an exception."""
    failing_client = MagicMock()
    failing_client.search.side_effect = Exception("RAG client error")
    mock_get_client.return_value = failing_client

    storage = KnowledgeStorage(collection_name="error_test")

    assert storage.search(["test query"]) == []
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_embedding_configuration_flow(
    mock_get_embedding: MagicMock, mock_get_client: MagicMock
) -> None:
    """Constructing KnowledgeStorage passes the embedder config to the factory."""
    mock_get_embedding.return_value = MagicMock()
    mock_get_client.return_value = MagicMock()

    embedder_config = {
        "provider": "sentence-transformer",
        "model_name": "all-MiniLM-L6-v2",
    }

    # Construction alone is expected to resolve the embedding function.
    KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")

    mock_get_embedding.assert_called_once_with(embedder_config)
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_query_list_conversion(mock_get_client: MagicMock) -> None:
    """The list of query strings reaches the client as one space-joined string."""
    rag_client = MagicMock()
    rag_client.search.return_value = []
    mock_get_client.return_value = rag_client

    storage = KnowledgeStorage()

    storage.search(["single query"])
    assert rag_client.search.call_args.kwargs["query"] == "single query"

    # Forget the first call so call_args reflects only the second search.
    rag_client.reset_mock()

    storage.search(["query one", "query two"])
    assert rag_client.search.call_args.kwargs["query"] == "query one query two"
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_metadata_filter_handling(mock_get_client: MagicMock) -> None:
    """``metadata_filter`` is passed through to the client, including ``None``."""
    rag_client = MagicMock()
    rag_client.search.return_value = []
    mock_get_client.return_value = rag_client

    storage = KnowledgeStorage()

    # Case 1: an explicit filter dict must arrive unchanged.
    metadata_filter = {"category": "technical", "priority": "high"}
    storage.search(["test"], metadata_filter=metadata_filter)
    forwarded = rag_client.search.call_args.kwargs
    assert forwarded["metadata_filter"] == metadata_filter

    # Case 2: an explicit None must stay None.
    rag_client.reset_mock()
    storage.search(["test"], metadata_filter=None)
    forwarded = rag_client.search.call_args.kwargs
    assert forwarded["metadata_filter"] is None
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None:
    """A "dimension mismatch" failure while saving surfaces as a ValueError."""
    rag_client = MagicMock()
    rag_client.get_or_create_collection.return_value = None
    rag_client.add_documents.side_effect = Exception("dimension mismatch detected")
    mock_get_client.return_value = rag_client

    storage = KnowledgeStorage(collection_name="dimension_test")

    with pytest.raises(ValueError, match="Embedding dimension mismatch"):
        storage.save(["test document"])
|
||||
@@ -1,19 +1,20 @@
|
||||
from unittest.mock import patch, ANY
|
||||
from collections import defaultdict
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.task import Task
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -38,22 +39,23 @@ def short_term_memory():
|
||||
def test_short_term_memory_search_events(short_term_memory):
|
||||
events = defaultdict(list)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
with patch("crewai.rag.chromadb.client.ChromaDBClient.search", return_value=[]):
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
|
||||
# Call the save method
|
||||
short_term_memory.search(
|
||||
query="test value",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
# Call the search method
|
||||
short_term_memory.search(
|
||||
query="test value",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 1
|
||||
@@ -173,12 +175,12 @@ def test_save_and_search(short_term_memory):
|
||||
|
||||
expected_result = [
|
||||
{
|
||||
"context": memory.data,
|
||||
"content": memory.data,
|
||||
"metadata": {"agent": "test_agent"},
|
||||
"score": 0.95,
|
||||
}
|
||||
]
|
||||
with patch.object(ShortTermMemory, "search", return_value=expected_result):
|
||||
find = short_term_memory.search("test value", score_threshold=0.01)[0]
|
||||
assert find["context"] == memory.data, "Data value mismatch."
|
||||
assert find["content"] == memory.data, "Data value mismatch."
|
||||
assert find["metadata"]["agent"] == "test_agent", "Agent value mismatch."
|
||||
|
||||
@@ -285,6 +285,43 @@ class TestChromaDBClient:
|
||||
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
||||
)
|
||||
|
||||
def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
    """add_documents fills in ``{}`` for records whose metadata is missing or None."""
    collection = Mock()
    mock_chromadb_client.get_collection.return_value = collection

    documents: list[BaseRecord] = [
        {"content": "Document without metadata"},
        {"content": "Another document", "metadata": None},
        {"content": "Document with metadata", "metadata": {"key": "value"}},
    ]

    client.add_documents(collection_name="test_collection", documents=documents)

    # Because at least one record carries metadata, the others are padded
    # with empty dicts so the list stays aligned with the documents.
    collection.upsert.assert_called_once()
    upsert_kwargs = collection.upsert.call_args[1]
    assert upsert_kwargs["metadatas"] == [{}, {}, {"key": "value"}]
|
||||
|
||||
def test_add_documents_all_without_metadata(
    self, client, mock_chromadb_client
) -> None:
    """add_documents passes ``metadatas=None`` when no record has any metadata."""
    collection = Mock()
    mock_chromadb_client.get_collection.return_value = collection

    documents: list[BaseRecord] = [
        {"content": f"Document {index}"} for index in (1, 2, 3)
    ]

    client.add_documents(collection_name="test_collection", documents=documents)

    collection.upsert.assert_called_once()
    upsert_kwargs = collection.upsert.call_args[1]
    assert upsert_kwargs["metadatas"] is None
|
||||
|
||||
def test_add_documents_empty_list_raises_error(
|
||||
self, client, mock_chromadb_client
|
||||
) -> None:
|
||||
@@ -358,6 +395,31 @@ class TestChromaDBClient:
|
||||
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
async def test_aadd_documents_without_metadata(
    self, async_client, mock_async_chromadb_client
) -> None:
    """aadd_documents fills in ``{}`` for records whose metadata is missing or None."""
    collection = AsyncMock()
    mock_async_chromadb_client.get_collection = AsyncMock(return_value=collection)

    documents: list[BaseRecord] = [
        {"content": "Document without metadata"},
        {"content": "Another document", "metadata": None},
        {"content": "Document with metadata", "metadata": {"key": "value"}},
    ]

    await async_client.aadd_documents(
        collection_name="test_collection", documents=documents
    )

    # Records without metadata are padded with empty dicts because one
    # record in the batch does carry metadata.
    collection.upsert.assert_called_once()
    upsert_kwargs = collection.upsert.call_args[1]
    assert upsert_kwargs["metadatas"] == [{}, {}, {"key": "value"}]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_empty_list_raises_error(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
|
||||
95
tests/rag/chromadb/test_utils.py
Normal file
95
tests/rag/chromadb/test_utils.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Tests for ChromaDB utility functions."""
|
||||
|
||||
from crewai.rag.chromadb.utils import (
|
||||
MAX_COLLECTION_LENGTH,
|
||||
MIN_COLLECTION_LENGTH,
|
||||
_is_ipv4_pattern,
|
||||
_sanitize_collection_name,
|
||||
)
|
||||
|
||||
|
||||
class TestChromaDBUtils:
    """Tests for the collection-name helpers in ``crewai.rag.chromadb.utils``."""

    def test_sanitize_collection_name_long_name(self) -> None:
        """An over-long name is shortened and stays well-formed."""
        result = _sanitize_collection_name(
            "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
        )
        assert len(result) <= MAX_COLLECTION_LENGTH
        assert result[0].isalnum()
        assert result[-1].isalnum()
        assert all(ch.isalnum() or ch in ["_", "-"] for ch in result)

    def test_sanitize_collection_name_special_chars(self) -> None:
        """Special characters never survive sanitization."""
        result = _sanitize_collection_name("Agent@123!#$%^&*()")
        assert result[0].isalnum()
        assert result[-1].isalnum()
        assert all(ch.isalnum() or ch in ["_", "-"] for ch in result)

    def test_sanitize_collection_name_short_name(self) -> None:
        """A one-character name is extended to the minimum length."""
        result = _sanitize_collection_name("A")
        assert len(result) >= MIN_COLLECTION_LENGTH
        assert result[0].isalnum()
        assert result[-1].isalnum()

    def test_sanitize_collection_name_bad_ends(self) -> None:
        """Leading and trailing underscores do not remain at the ends."""
        result = _sanitize_collection_name("_Agent_")
        assert result[0].isalnum()
        assert result[-1].isalnum()

    def test_sanitize_collection_name_none(self) -> None:
        """``None`` maps to the default collection name."""
        assert _sanitize_collection_name(None) == "default_collection"

    def test_sanitize_collection_name_ipv4_pattern(self) -> None:
        """An IPv4-looking name is prefixed with ``ip_`` and otherwise cleaned."""
        result = _sanitize_collection_name("192.168.1.1")
        assert result.startswith("ip_")
        assert result[0].isalnum()
        assert result[-1].isalnum()
        assert all(ch.isalnum() or ch in ["_", "-"] for ch in result)

    def test_is_ipv4_pattern(self) -> None:
        """The IPv4 detector accepts a dotted quad and rejects dotted words."""
        looks_like_ip = _is_ipv4_pattern("192.168.1.1")
        looks_like_words = _is_ipv4_pattern("not.an.ip.address")
        assert looks_like_ip is True
        assert looks_like_words is False

    def test_sanitize_collection_name_properties(self) -> None:
        """Every sample input sanitizes to a name within bounds and alnum-edged."""
        test_cases: list[str] = [
            "A" * 100,  # Very long name
            "_start_with_underscore",
            "end_with_underscore_",
            "contains@special#characters",
            "192.168.1.1",  # IPv4 address
            "a" * 2,  # Too short
        ]
        for raw_name in test_cases:
            result = _sanitize_collection_name(raw_name)
            assert len(result) >= MIN_COLLECTION_LENGTH
            assert len(result) <= MAX_COLLECTION_LENGTH
            assert result[0].isalnum()
            assert result[-1].isalnum()

    def test_sanitize_collection_name_empty_string(self) -> None:
        """The empty string maps to the default collection name."""
        assert _sanitize_collection_name("") == "default_collection"

    def test_sanitize_collection_name_whitespace_only(self) -> None:
        """A whitespace-only name is turned into a valid padded name."""
        result = _sanitize_collection_name(" ")
        # Spaces become underscores, padded to meet requirements
        assert result == "a__z"
        assert len(result) >= MIN_COLLECTION_LENGTH
        assert result[0].isalnum()
        assert result[-1].isalnum()
|
||||
250
tests/rag/embeddings/test_factory_enhanced.py
Normal file
250
tests/rag/embeddings/test_factory_enhanced.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Enhanced tests for embedding function factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
|
||||
get_embedding_function,
|
||||
)
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
|
||||
|
||||
|
||||
def test_get_embedding_function_default() -> None:
    """With no config, the factory builds an OpenAI embedder from the environment.

    ``os.getenv`` is patched to supply the API key, and the expected default
    model is ``text-embedding-3-small``.
    """
    openai_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.OpenAIEmbeddingFunction",
        return_value=openai_instance,
    ) as openai_cls:
        # Only the factory call runs with the patched environment lookup.
        with patch(
            "crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
        ):
            result = get_embedding_function()

        openai_cls.assert_called_once_with(
            api_key="test-api-key", model_name="text-embedding-3-small"
        )
        assert result == openai_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_with_embedding_options() -> None:
    """The factory accepts an EmbeddingOptions object and keeps the key secret-wrapped."""
    openai_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.OpenAIEmbeddingFunction",
        return_value=openai_instance,
    ) as openai_cls:
        options = EmbeddingOptions(
            provider="openai", api_key="test-key", model="text-embedding-3-large"
        )

        result = get_embedding_function(options)

        forwarded = openai_cls.call_args.kwargs
        assert "api_key" in forwarded
        # The key arrives as a secret wrapper, so unwrap it before comparing.
        assert forwarded["api_key"].get_secret_value() == "test-key"
        # OpenAI uses model_name parameter, not model
        assert result == openai_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_sentence_transformer() -> None:
    """The ``sentence-transformer`` provider builds its class with the model name."""
    st_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction",
        return_value=st_instance,
    ) as st_cls:
        result = get_embedding_function(
            {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"}
        )

        # The provider key itself must not be forwarded to the constructor.
        st_cls.assert_called_once_with(model_name="all-MiniLM-L6-v2")
        assert result == st_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_ollama() -> None:
    """The ``ollama`` provider forwards model name and URL to its class."""
    ollama_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.OllamaEmbeddingFunction",
        return_value=ollama_instance,
    ) as ollama_cls:
        result = get_embedding_function(
            {
                "provider": "ollama",
                "model_name": "nomic-embed-text",
                "url": "http://localhost:11434",
            }
        )

        ollama_cls.assert_called_once_with(
            model_name="nomic-embed-text", url="http://localhost:11434"
        )
        assert result == ollama_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_cohere() -> None:
    """The ``cohere`` provider forwards API key and model name to its class."""
    cohere_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.CohereEmbeddingFunction",
        return_value=cohere_instance,
    ) as cohere_cls:
        result = get_embedding_function(
            {
                "provider": "cohere",
                "api_key": "cohere-key",
                "model_name": "embed-english-v3.0",
            }
        )

        cohere_cls.assert_called_once_with(
            api_key="cohere-key", model_name="embed-english-v3.0"
        )
        assert result == cohere_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_huggingface() -> None:
    """The ``huggingface`` provider forwards API key and model name to its class."""
    hf_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction",
        return_value=hf_instance,
    ) as hf_cls:
        result = get_embedding_function(
            {
                "provider": "huggingface",
                "api_key": "hf-token",
                "model_name": "sentence-transformers/all-MiniLM-L6-v2",
            }
        )

        hf_cls.assert_called_once_with(
            api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        assert result == hf_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_onnx() -> None:
    """The ``onnx`` provider instantiates ONNXMiniLM_L6_V2 exactly once."""
    onnx_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2", return_value=onnx_instance
    ) as onnx_cls:
        result = get_embedding_function({"provider": "onnx"})

        onnx_cls.assert_called_once()
        assert result == onnx_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_google_palm() -> None:
    """The ``google-palm`` provider forwards only the API key to its class."""
    palm_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction",
        return_value=palm_instance,
    ) as palm_cls:
        result = get_embedding_function(
            {"provider": "google-palm", "api_key": "palm-key"}
        )

        palm_cls.assert_called_once_with(api_key="palm-key")
        assert result == palm_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_amazon_bedrock() -> None:
    """The ``amazon-bedrock`` provider forwards region and model name to its class."""
    bedrock_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction",
        return_value=bedrock_instance,
    ) as bedrock_cls:
        result = get_embedding_function(
            {
                "provider": "amazon-bedrock",
                "region_name": "us-west-2",
                "model_name": "amazon.titan-embed-text-v1",
            }
        )

        bedrock_cls.assert_called_once_with(
            region_name="us-west-2", model_name="amazon.titan-embed-text-v1"
        )
        assert result == bedrock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_jina() -> None:
    """The ``jina`` provider forwards API key and model name to its class."""
    jina_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.JinaEmbeddingFunction",
        return_value=jina_instance,
    ) as jina_cls:
        result = get_embedding_function(
            {
                "provider": "jina",
                "api_key": "jina-key",
                "model_name": "jina-embeddings-v2-base-en",
            }
        )

        jina_cls.assert_called_once_with(
            api_key="jina-key", model_name="jina-embeddings-v2-base-en"
        )
        assert result == jina_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_unsupported_provider() -> None:
    """An unknown provider name is rejected with a ValueError naming it."""
    with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
        get_embedding_function({"provider": "unsupported-provider"})
|
||||
|
||||
|
||||
def test_get_embedding_function_config_modification() -> None:
    """The dict handed to the factory must come back unmodified."""
    original_config = {
        "provider": "openai",
        "api_key": "test-key",
        "model": "text-embedding-3-small",
    }
    # Pass a copy so the pristine original can serve as the comparison baseline.
    passed_config = original_config.copy()

    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
        get_embedding_function(passed_config)

    assert passed_config == original_config
|
||||
|
||||
|
||||
def test_get_embedding_function_exclude_none_values() -> None:
    """Options left as ``None`` are not forwarded to the embedding constructor."""
    openai_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.OpenAIEmbeddingFunction",
        return_value=openai_instance,
    ) as openai_cls:
        options = EmbeddingOptions(provider="openai", api_key="test-key", model=None)

        result = get_embedding_function(options)

        forwarded = openai_cls.call_args.kwargs
        assert "api_key" in forwarded
        assert forwarded["api_key"].get_secret_value() == "test-key"
        # model was None, so it must have been dropped rather than passed along.
        assert "model" not in forwarded
        assert result == openai_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_instructor() -> None:
    """The ``instructor`` provider builds its class with the model name."""
    instructor_instance = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.InstructorEmbeddingFunction",
        return_value=instructor_instance,
    ) as instructor_cls:
        result = get_embedding_function(
            {"provider": "instructor", "model_name": "hkunlp/instructor-large"}
        )

        instructor_cls.assert_called_once_with(model_name="hkunlp/instructor-large")
        assert result == instructor_instance
|
||||
218
tests/rag/test_error_handling.py
Normal file
218
tests/rag/test_error_handling.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Tests for RAG client error handling scenarios."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
|
||||
KnowledgeStorage,
|
||||
)
|
||||
from crewai.memory.storage.rag_storage import RAGStorage # type: ignore[import-untyped]
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_connection_failure(mock_get_client: MagicMock) -> None:
    """Expect search() to return an empty list when get_rag_client raises ConnectionError."""
    # Any attempt to obtain the RAG client fails with a connection error.
    mock_get_client.configure_mock(
        side_effect=ConnectionError("Unable to connect to ChromaDB")
    )

    knowledge = KnowledgeStorage(collection_name="connection_test")
    found = knowledge.search(["test query"])

    assert found == []
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_search_timeout(mock_get_client: MagicMock) -> None:
    """Expect search() to return an empty list when the client's search times out."""
    slow_client = MagicMock()
    mock_get_client.return_value = slow_client
    # Every client.search call raises TimeoutError.
    slow_client.search.configure_mock(
        side_effect=TimeoutError("Search operation timed out")
    )

    knowledge = KnowledgeStorage(collection_name="timeout_test")
    found = knowledge.search(["test query"])

    assert found == []
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_collection_not_found(mock_get_client: MagicMock) -> None:
    """Expect search() to return an empty list when the client reports a missing collection."""
    missing_error = ValueError("Collection 'knowledge_missing' does not exist")

    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.search.side_effect = missing_error

    knowledge = KnowledgeStorage(collection_name="missing_collection")
    found = knowledge.search(["test query"])

    assert found == []
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock) -> None:
    """Expect the constructor to surface a ValueError raised by get_embedding_function.

    get_embedding_function is patched to raise for the unsupported provider;
    building a KnowledgeStorage with that embedder must raise the same message.
    """
    mock_get_client.return_value = MagicMock()

    provider_error = ValueError("Unsupported provider: invalid_provider")
    embedding_target = "crewai.knowledge.storage.knowledge_storage.get_embedding_function"

    # side_effect is supplied at patch time instead of being set on the mock afterwards.
    with patch(embedding_target, side_effect=provider_error):
        with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
            KnowledgeStorage(
                embedder={"provider": "invalid_provider"},
                collection_name="invalid_embedding_test",
            )
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_client_failure(mock_get_client: MagicMock) -> None:
    """Expect RAGStorage.search() to return an empty list when the client's search raises."""
    broken_client = MagicMock()
    mock_get_client.return_value = broken_client
    broken_client.search.configure_mock(
        side_effect=RuntimeError("ChromaDB server error")
    )

    memory = RAGStorage("short_term", crew=None)
    found = memory.search("test query")

    assert found == []
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_save_failure(mock_get_client: MagicMock) -> None:
    """Expect RAGStorage.save() not to raise when the client's add_documents fails.

    There is no explicit assertion: the test passes as long as save() returns
    without propagating the exception.
    """
    failing_client = MagicMock()
    mock_get_client.return_value = failing_client
    failing_client.add_documents.configure_mock(
        side_effect=Exception("Failed to add documents")
    )

    memory = RAGStorage("long_term", crew=None)
    memory.save("test memory", {"key": "value"})
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_readonly_database(mock_get_client: MagicMock) -> None:
    """Expect reset() not to raise when delete_collection reports a readonly database.

    There is no explicit assertion: the test passes as long as reset() returns.
    """
    readonly_error = Exception("attempt to write a readonly database")

    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.delete_collection.side_effect = readonly_error

    knowledge = KnowledgeStorage(collection_name="readonly_test")
    knowledge.reset()
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_collection_does_not_exist(
    mock_get_client: MagicMock,
) -> None:
    """Expect reset() not to raise when delete_collection says the collection is absent.

    There is no explicit assertion: the test passes as long as reset() returns.
    """
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.delete_collection.configure_mock(
        side_effect=Exception("Collection does not exist")
    )

    knowledge = KnowledgeStorage(collection_name="nonexistent_test")
    knowledge.reset()
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_reset_failure_propagation(mock_get_client: MagicMock) -> None:
    """Expect RAGStorage.reset() to raise when delete_collection fails unexpectedly.

    The raised exception's message must match the "resetting the entities
    memory" text checked below.
    """
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.delete_collection.configure_mock(
        side_effect=Exception("Unexpected database error")
    )

    memory = RAGStorage("entities", crew=None)

    expected_message = "An error occurred while resetting the entities memory"
    with pytest.raises(Exception, match=expected_message):
        memory.reset()
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_malformed_search_results(mock_get_client: MagicMock) -> None:
    """Expect search() to return a four-item list when the client returns malformed hits.

    The client returns four entries: one well-formed, one without a
    ``content`` key, a bare ``None``, and one whose ``content`` is ``None``.
    """
    raw_hits = [
        {"content": "valid result", "metadata": {"source": "test"}},
        {"invalid": "missing content field", "metadata": {"source": "test"}},
        None,
        {"content": None, "metadata": {"source": "test"}},
    ]

    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.search.return_value = raw_hits

    knowledge = KnowledgeStorage(collection_name="malformed_test")
    found = knowledge.search(["test query"])

    assert isinstance(found, list)
    assert len(found) == 4
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_network_interruption(mock_get_client: MagicMock) -> None:
    """Expect search() to yield [] on a ConnectionError and real results once the client recovers."""

    def recovered_hits() -> list:
        # Builds a fresh, equal-valued result list on every call.
        return [
            {"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}
        ]

    rag_client = MagicMock()
    mock_get_client.return_value = rag_client

    # Initial sequence: the first client.search call raises, the next would return hits.
    rag_client.search.side_effect = [
        ConnectionError("Network interruption"),
        recovered_hits(),
    ]

    knowledge = KnowledgeStorage(collection_name="network_test")

    during_outage = knowledge.search(["test query"])
    assert during_outage == []

    # Drop the side_effect sequence and return the recovered hits directly.
    rag_client.search.side_effect = None
    rag_client.search.return_value = recovered_hits()

    after_recovery = knowledge.search(["test query"])
    assert len(after_recovery) == 1
    assert after_recovery[0]["content"] == "recovered result"
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_collection_creation_failure(mock_get_client: MagicMock) -> None:
    """Expect RAGStorage.save() not to raise when get_or_create_collection fails.

    There is no explicit assertion: the test passes as long as save() returns.
    """
    creation_error = Exception("Failed to create collection")

    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.get_or_create_collection.side_effect = creation_error

    memory = RAGStorage("user_memory", crew=None)
    memory.save("test data", {"metadata": "test"})
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_embedding_dimension_mismatch_detailed(
    mock_get_client: MagicMock,
) -> None:
    """Expect save() to raise a ValueError with guidance on a dimension mismatch.

    add_documents fails with a dimension-mismatch message; the ValueError
    raised by save() must mention the mismatch, advise using the same
    embedding model, and name the reset command.
    """
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.get_or_create_collection.return_value = None
    rag_client.add_documents.configure_mock(
        side_effect=Exception("Embedding dimension mismatch: expected 384, got 1536")
    )

    knowledge = KnowledgeStorage(collection_name="dimension_detailed_test")

    with pytest.raises(ValueError) as exc_info:
        knowledge.save(["test document"])

    message = str(exc_info.value)
    expected_fragments = (
        "Embedding dimension mismatch",
        "Make sure you're using the same embedding model",
        "crewai reset-memories -a",
    )
    for fragment in expected_fragments:
        assert fragment in message
|
||||
@@ -1,8 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from mem0.client.main import MemoryClient
|
||||
from mem0.memory.main import Memory
|
||||
from mem0 import Memory, MemoryClient
|
||||
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
@@ -13,6 +12,67 @@ class MockCrew:
|
||||
self.agents = [MagicMock(role="Test Agent")]
|
||||
|
||||
|
||||
# Test data constants
|
||||
SYSTEM_CONTENT = (
|
||||
"You are Friendly chatbot assistant. You are a kind and "
|
||||
"knowledgeable chatbot assistant. You excel at understanding user needs, "
|
||||
"providing helpful responses, and maintaining engaging conversations. "
|
||||
"You remember previous interactions to provide a personalized experience.\n"
|
||||
"Your personal goal is: Engage in useful and interesting conversations "
|
||||
"with users while remembering context.\n"
|
||||
"To give my best complete final answer to the task respond using the exact "
|
||||
"following format:\n\n"
|
||||
"Thought: I now can give a great answer\n"
|
||||
"Final Answer: Your final answer must be the great and the most complete "
|
||||
"as possible, it must be outcome described.\n\n"
|
||||
"I MUST use these formats, my job depends on it!"
|
||||
)
|
||||
|
||||
USER_CONTENT = (
|
||||
"\nCurrent Task: Respond to user conversation. User message: "
|
||||
"What do you know about me?\n\n"
|
||||
"This is the expected criteria for your final answer: Contextually "
|
||||
"appropriate, helpful, and friendly response.\n"
|
||||
"you MUST return the actual complete content as the final answer, "
|
||||
"not a summary.\n\n"
|
||||
"# Useful context: \nExternal memories:\n"
|
||||
"- User is from India\n"
|
||||
"- User is interested in the solar system\n"
|
||||
"- User name is Vidit Ostwal\n"
|
||||
"- User is interested in French cuisine\n\n"
|
||||
"Begin! This is VERY important to you, use the tools available and give "
|
||||
"your best Final Answer, your job depends on it!\n\n"
|
||||
"Thought:"
|
||||
)
|
||||
|
||||
ASSISTANT_CONTENT = (
|
||||
"I now can give a great answer \n"
|
||||
"Final Answer: Hi Vidit! From our previous conversations, I know you're "
|
||||
"from India and have a great interest in the solar system. It's fascinating "
|
||||
"to explore the wonders of space, isn't it? Also, I remember you have a "
|
||||
"passion for French cuisine, which has so many delightful dishes to explore. "
|
||||
"If there's anything specific you'd like to discuss or learn about—whether "
|
||||
"it's about the solar system or some great French recipes—feel free to let "
|
||||
"me know! I'm here to help."
|
||||
)
|
||||
|
||||
TEST_DESCRIPTION = (
|
||||
"Respond to user conversation. User message: What do you know about me?"
|
||||
)
|
||||
|
||||
# Extracted content (after processing by _get_user_message and _get_assistant_message)
|
||||
EXTRACTED_USER_CONTENT = "What do you know about me?"
|
||||
EXTRACTED_ASSISTANT_CONTENT = (
|
||||
"Hi Vidit! From our previous conversations, I know you're "
|
||||
"from India and have a great interest in the solar system. It's fascinating "
|
||||
"to explore the wonders of space, isn't it? Also, I remember you have a "
|
||||
"passion for French cuisine, which has so many delightful dishes to explore. "
|
||||
"If there's anything specific you'd like to discuss or learn about—whether "
|
||||
"it's about the solar system or some great French recipes—feel free to let "
|
||||
"me know! I'm here to help."
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory():
|
||||
"""Fixture to create a mock Memory instance"""
|
||||
@@ -24,7 +84,9 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
|
||||
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
|
||||
|
||||
# Patch the Memory class to return our mock
|
||||
with patch("mem0.memory.main.Memory.from_config", return_value=mock_mem0_memory) as mock_from_config:
|
||||
with patch(
|
||||
"mem0.Memory.from_config", return_value=mock_mem0_memory
|
||||
) as mock_from_config:
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "mock_vector_store",
|
||||
@@ -55,7 +117,14 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
|
||||
# Parameters like run_id, includes, and excludes don't matter in Memory OSS
|
||||
crew = MockCrew()
|
||||
|
||||
embedder_config={"user_id": "test_user", "local_mem0_config": config, "run_id": "my_run_id", "includes": "include1","excludes": "exclude1", "infer" : True}
|
||||
embedder_config = {
|
||||
"user_id": "test_user",
|
||||
"local_mem0_config": config,
|
||||
"run_id": "my_run_id",
|
||||
"includes": "include1",
|
||||
"excludes": "exclude1",
|
||||
"infer": True,
|
||||
}
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=embedder_config)
|
||||
return mem0_storage, mock_from_config, config
|
||||
@@ -83,28 +152,31 @@ def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_clie
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
crew = MockCrew()
|
||||
|
||||
embedder_config={
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"run_id": "my_run_id",
|
||||
"includes": "include1",
|
||||
"excludes": "exclude1",
|
||||
"infer": True
|
||||
}
|
||||
embedder_config = {
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"run_id": "my_run_id",
|
||||
"includes": "include1",
|
||||
"excludes": "exclude1",
|
||||
"infer": True,
|
||||
}
|
||||
|
||||
return Mem0Storage(type="short_term", crew=crew, config=embedder_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client, mock_mem0_memory):
|
||||
def mem0_storage_with_memory_client_using_explictly_config(
|
||||
mock_mem0_memory_client, mock_mem0_memory
|
||||
):
|
||||
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
|
||||
|
||||
# We need to patch both MemoryClient and Memory to prevent actual initialization
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client), \
|
||||
patch.object(Memory, "__new__", return_value=mock_mem0_memory):
|
||||
|
||||
with (
|
||||
patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client),
|
||||
patch.object(Memory, "__new__", return_value=mock_mem0_memory),
|
||||
):
|
||||
crew = MockCrew()
|
||||
new_config = {"provider": "mem0", "config": {"api_key": "new-api-key"}}
|
||||
|
||||
@@ -138,18 +210,23 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
|
||||
mock_mem0_memory_client.update_project = MagicMock()
|
||||
|
||||
new_categories = [
|
||||
{"lifestyle_management_concerns": "Tracks daily routines, habits, hobbies and interests including cooking, time management and work-life balance"},
|
||||
{
|
||||
"lifestyle_management_concerns": (
|
||||
"Tracks daily routines, habits, hobbies and interests "
|
||||
"including cooking, time management and work-life balance"
|
||||
)
|
||||
},
|
||||
]
|
||||
|
||||
crew = MockCrew()
|
||||
|
||||
config={
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"custom_categories": new_categories
|
||||
}
|
||||
config = {
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"custom_categories": new_categories,
|
||||
}
|
||||
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
_ = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
@@ -159,8 +236,6 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
"""Test save method for different memory types"""
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
@@ -168,68 +243,134 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
|
||||
# Test short_term memory type (already set in fixture)
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {'description': 'Respond to user conversation. User message: What do you know about me?', 'messages': [{'role': 'system', 'content': 'You are Friendly chatbot assistant. You are a kind and knowledgeable chatbot assistant. You excel at understanding user needs, providing helpful responses, and maintaining engaging conversations. You remember previous interactions to provide a personalized experience.\nYour personal goal is: Engage in useful and interesting conversations with users while remembering context.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!'}, {'role': 'user', 'content': '\nCurrent Task: Respond to user conversation. User message: What do you know about me?\n\nThis is the expected criteria for your final answer: Contextually appropriate, helpful, and friendly response.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n- User is from India\n- User is interested in the solar system\n- User name is Vidit Ostwal\n- User is interested in French cuisine\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'}, {'role': 'assistant', 'content': "I now can give a great answer \nFinal Answer: Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! 
I'm here to help."}], 'agent': 'Friendly chatbot assistant'}
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{'role': 'user', 'content': 'What do you know about me?'}, {'role': 'assistant', 'content': "Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}],
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
infer=True,
|
||||
metadata={'type': 'short_term', 'description': 'Respond to user conversation. User message: What do you know about me?', 'agent': 'Friendly chatbot assistant'},
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
run_id="my_run_id",
|
||||
user_id="test_user",
|
||||
agent_id='Test_Agent'
|
||||
agent_id="Test_Agent",
|
||||
)
|
||||
|
||||
|
||||
def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config):
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")]
|
||||
mem0_storage.crew.agents = [
|
||||
MagicMock(role="Test Agent"),
|
||||
MagicMock(role="Test Agent 2"),
|
||||
MagicMock(role="Test Agent 3"),
|
||||
]
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {'description': 'Respond to user conversation. User message: What do you know about me?', 'messages': [{'role': 'system', 'content': 'You are Friendly chatbot assistant. You are a kind and knowledgeable chatbot assistant. You excel at understanding user needs, providing helpful responses, and maintaining engaging conversations. You remember previous interactions to provide a personalized experience.\nYour personal goal is: Engage in useful and interesting conversations with users while remembering context.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!'}, {'role': 'user', 'content': '\nCurrent Task: Respond to user conversation. User message: What do you know about me?\n\nThis is the expected criteria for your final answer: Contextually appropriate, helpful, and friendly response.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n- User is from India\n- User is interested in the solar system\n- User name is Vidit Ostwal\n- User is interested in French cuisine\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'}, {'role': 'assistant', 'content': "I now can give a great answer \nFinal Answer: Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! 
I'm here to help."}], 'agent': 'Friendly chatbot assistant'}
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{'role': 'user', 'content': 'What do you know about me?'}, {'role': 'assistant', 'content': "Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}],
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
infer=True,
|
||||
metadata={'type': 'short_term', 'description': 'Respond to user conversation. User message: What do you know about me?', 'agent': 'Friendly chatbot assistant'},
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
run_id="my_run_id",
|
||||
user_id="test_user",
|
||||
agent_id='Test_Agent_Test_Agent_2_Test_Agent_3'
|
||||
agent_id="Test_Agent_Test_Agent_2_Test_Agent_3",
|
||||
)
|
||||
|
||||
|
||||
def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
|
||||
def test_save_method_with_memory_client(
|
||||
mem0_storage_with_memory_client_using_config_from_crew,
|
||||
):
|
||||
"""Test save method for different memory types"""
|
||||
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
# Test short_term memory type (already set in fixture)
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {'description': 'Respond to user conversation. User message: What do you know about me?', 'messages': [{'role': 'system', 'content': 'You are Friendly chatbot assistant. You are a kind and knowledgeable chatbot assistant. You excel at understanding user needs, providing helpful responses, and maintaining engaging conversations. You remember previous interactions to provide a personalized experience.\nYour personal goal is: Engage in useful and interesting conversations with users while remembering context.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!'}, {'role': 'user', 'content': '\nCurrent Task: Respond to user conversation. User message: What do you know about me?\n\nThis is the expected criteria for your final answer: Contextually appropriate, helpful, and friendly response.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n- User is from India\n- User is interested in the solar system\n- User name is Vidit Ostwal\n- User is interested in French cuisine\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'}, {'role': 'assistant', 'content': "I now can give a great answer \nFinal Answer: Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! 
I'm here to help."}], 'agent': 'Friendly chatbot assistant'}
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{'role': 'user', 'content': 'What do you know about me?'}, {'role': 'assistant', 'content': "Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}],
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
infer=True,
|
||||
metadata={'type': 'short_term', 'description': 'Respond to user conversation. User message: What do you know about me?', 'agent': 'Friendly chatbot assistant'},
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
version="v2",
|
||||
run_id="my_run_id",
|
||||
includes="include1",
|
||||
excludes="exclude1",
|
||||
output_format='v1.1',
|
||||
user_id='test_user',
|
||||
agent_id='Test_Agent'
|
||||
output_format="v1.1",
|
||||
user_id="test_user",
|
||||
agent_id="Test_Agent",
|
||||
)
|
||||
|
||||
|
||||
def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
"""Test search method for different memory types"""
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
@@ -238,18 +379,25 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
query="test query",
|
||||
limit=5,
|
||||
user_id="test_user",
|
||||
filters={'AND': [{'run_id': 'my_run_id'}]},
|
||||
threshold=0.5
|
||||
filters={"AND": [{"run_id": "my_run_id"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["context"] == "Result 1"
|
||||
assert results[0]["content"] == "Result 1"
|
||||
|
||||
|
||||
def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
|
||||
def test_search_method_with_memory_client(
|
||||
mem0_storage_with_memory_client_using_config_from_crew,
|
||||
):
|
||||
"""Test search method for different memory types"""
|
||||
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
|
||||
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
@@ -259,15 +407,15 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_
|
||||
limit=5,
|
||||
metadata={"type": "short_term"},
|
||||
user_id="test_user",
|
||||
version='v2',
|
||||
version="v2",
|
||||
run_id="my_run_id",
|
||||
output_format='v1.1',
|
||||
filters={'AND': [{'run_id': 'my_run_id'}]},
|
||||
threshold=0.5
|
||||
output_format="v1.1",
|
||||
filters={"AND": [{"run_id": "my_run_id"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["context"] == "Result 1"
|
||||
assert results[0]["content"] == "Result 1"
|
||||
|
||||
|
||||
def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
|
||||
@@ -275,14 +423,12 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
crew = MockCrew()
|
||||
|
||||
config={
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH"
|
||||
}
|
||||
config = {"user_id": "test_user", "api_key": "ABCDEFGH"}
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
assert mem0_storage.infer is True
|
||||
|
||||
|
||||
def test_save_memory_using_agent_entity(mock_mem0_memory_client):
|
||||
config = {
|
||||
"agent_id": "agent-123",
|
||||
@@ -293,19 +439,25 @@ def test_save_memory_using_agent_entity(mock_mem0_memory_client):
|
||||
mem0_storage = Mem0Storage(type="external", config=config)
|
||||
mem0_storage.save("test memory", {"key": "value"})
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{'role': 'assistant' , 'content': 'test memory'}],
|
||||
[{"role": "assistant", "content": "test memory"}],
|
||||
infer=True,
|
||||
metadata={"type": "external", "key": "value"},
|
||||
agent_id="agent-123",
|
||||
)
|
||||
|
||||
|
||||
def test_search_method_with_agent_entity():
|
||||
config = {
|
||||
"agent_id": "agent-123",
|
||||
}
|
||||
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch.object(Memory, "__new__", return_value=mock_memory):
|
||||
mem0_storage = Mem0Storage(type="external", config=config)
|
||||
@@ -314,22 +466,29 @@ def test_search_method_with_agent_entity():
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
filters={"AND": [{"agent_id": "agent-123"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
query="test query",
|
||||
limit=5,
|
||||
filters={"AND": [{"agent_id": "agent-123"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["context"] == "Result 1"
|
||||
assert results[0]["content"] == "Result 1"
|
||||
|
||||
|
||||
def test_search_method_with_agent_id_and_user_id():
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch.object(Memory, "__new__", return_value=mock_memory):
|
||||
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"})
|
||||
mem0_storage = Mem0Storage(
|
||||
type="external", config={"agent_id": "agent-123", "user_id": "user-123"}
|
||||
)
|
||||
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
@@ -337,10 +496,10 @@ def test_search_method_with_agent_id_and_user_id():
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
user_id='user-123',
|
||||
user_id="user-123",
|
||||
filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["context"] == "Result 1"
|
||||
assert results[0]["content"] == "Result 1"
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
import multiprocessing
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from chromadb.config import Settings
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai.utilities.chromadb import (
|
||||
MAX_COLLECTION_LENGTH,
|
||||
MIN_COLLECTION_LENGTH,
|
||||
is_ipv4_pattern,
|
||||
sanitize_collection_name,
|
||||
create_persistent_client,
|
||||
)
|
||||
|
||||
|
||||
def persistent_client_worker(path, queue):
    """Create a persistent client at ``path`` and report the outcome on ``queue``.

    Used as the ``multiprocessing.Process`` target in
    ``TestChromadbUtils.test_create_persistent_client_process_safe``.

    Args:
        path: Directory handed to ``create_persistent_client``.
        queue: Queue that receives ``None`` on success, or the exception
            instance that was raised.
    """
    try:
        create_persistent_client(path=path)
        # ``None`` is the success marker the test checks for.
        queue.put(None)
    except Exception as error:
        # Forward the failure to the caller instead of raising here.
        queue.put(error)
|
||||
|
||||
|
||||
class TestChromadbUtils(unittest.TestCase):
    """Tests for the helpers imported from ``crewai.utilities.chromadb``.

    Covers ``sanitize_collection_name``, ``is_ipv4_pattern`` and
    ``create_persistent_client``.
    """

    def _assert_alnum_ends(self, sanitized):
        """Assert that ``sanitized`` starts and ends with an alphanumeric character."""
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())

    def _assert_allowed_chars(self, sanitized):
        """Assert that ``sanitized`` contains only alphanumerics, ``_`` and ``-``."""
        self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))

    def test_sanitize_collection_name_long_name(self):
        """Test sanitizing a very long collection name."""
        long_name = (
            "This is an extremely long role name that will definitely exceed "
            "the ChromaDB collection name limit of 63 characters and cause an "
            "error when used as a collection name"
        )
        sanitized = sanitize_collection_name(long_name)
        self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
        self._assert_alnum_ends(sanitized)
        self._assert_allowed_chars(sanitized)

    def test_sanitize_collection_name_special_chars(self):
        """Test sanitizing a name with special characters."""
        special_chars = "Agent@123!#$%^&*()"
        sanitized = sanitize_collection_name(special_chars)
        self._assert_alnum_ends(sanitized)
        self._assert_allowed_chars(sanitized)

    def test_sanitize_collection_name_short_name(self):
        """Test sanitizing a very short name."""
        short_name = "A"
        sanitized = sanitize_collection_name(short_name)
        self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
        self._assert_alnum_ends(sanitized)

    def test_sanitize_collection_name_bad_ends(self):
        """Test sanitizing a name with non-alphanumeric start/end."""
        bad_ends = "_Agent_"
        sanitized = sanitize_collection_name(bad_ends)
        self._assert_alnum_ends(sanitized)

    def test_sanitize_collection_name_none(self):
        """Test sanitizing a None value."""
        sanitized = sanitize_collection_name(None)
        self.assertEqual(sanitized, "default_collection")

    def test_sanitize_collection_name_ipv4_pattern(self):
        """Test sanitizing an IPv4 address."""
        ipv4 = "192.168.1.1"
        sanitized = sanitize_collection_name(ipv4)
        self.assertTrue(sanitized.startswith("ip_"))
        self._assert_alnum_ends(sanitized)
        self._assert_allowed_chars(sanitized)

    def test_is_ipv4_pattern(self):
        """Test IPv4 pattern detection."""
        self.assertTrue(is_ipv4_pattern("192.168.1.1"))
        self.assertFalse(is_ipv4_pattern("not.an.ip.address"))

    def test_sanitize_collection_name_properties(self):
        """Test that sanitized collection names always meet ChromaDB requirements."""
        test_cases = [
            "A" * 100,  # Very long name
            "_start_with_underscore",
            "end_with_underscore_",
            "contains@special#characters",
            "192.168.1.1",  # IPv4 address
            "a" * 2,  # Too short
        ]
        for test_case in test_cases:
            # subTest names the failing input and lets the remaining cases run.
            with self.subTest(name=test_case):
                sanitized = sanitize_collection_name(test_case)
                self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
                self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
                self._assert_alnum_ends(sanitized)

    def test_create_persistent_client_passes_args(self):
        """Test that path and settings are forwarded to ``PersistentClient`` unchanged."""
        with patch(
            "crewai.utilities.chromadb.PersistentClient"
        ) as mock_persistent_client, tempfile.TemporaryDirectory() as tmpdir:
            mock_instance = MagicMock()
            mock_persistent_client.return_value = mock_instance

            settings = Settings(allow_reset=True)
            client = create_persistent_client(path=tmpdir, settings=settings)

            mock_persistent_client.assert_called_once_with(
                path=tmpdir, settings=settings
            )
            # The helper must return exactly what PersistentClient produced.
            self.assertIs(client, mock_instance)

    def test_create_persistent_client_process_safe(self):
        """Test that five processes can create a client on the same path without error."""
        with tempfile.TemporaryDirectory() as tmpdir:
            queue = multiprocessing.Queue()
            processes = [
                multiprocessing.Process(
                    target=persistent_client_worker, args=(tmpdir, queue)
                )
                for _ in range(5)
            ]

            for process in processes:
                process.start()
            # NOTE(review): the multiprocessing guidelines advise draining a
            # Queue before joining its producers. Join-first is kept so the
            # 5 s timeout below only bounds reading the results, not worker
            # start-up time.
            for process in processes:
                process.join()

            # Each worker puts exactly one item: None on success, else the exception.
            errors = [queue.get(timeout=5) for _ in processes]
            self.assertTrue(all(err is None for err in errors))
|
||||
@@ -29,13 +29,15 @@ def mock_knowledge_source():
|
||||
"""
|
||||
return StringKnowledgeSource(content=content)
|
||||
|
||||
@patch('crewai.knowledge.storage.knowledge_storage.chromadb')
|
||||
def test_knowledge_included_in_planning(mock_chroma):
|
||||
|
||||
@patch("crewai.rag.config.utils.get_rag_client")
|
||||
def test_knowledge_included_in_planning(mock_get_client):
|
||||
"""Test that verifies knowledge sources are properly included in planning."""
|
||||
# Mock ChromaDB collection
|
||||
mock_collection = mock_chroma.return_value.get_or_create_collection.return_value
|
||||
mock_collection.add.return_value = None
|
||||
|
||||
# Mock RAG client
|
||||
mock_client = mock_get_client.return_value
|
||||
mock_client.get_or_create_collection.return_value = None
|
||||
mock_client.add_documents.return_value = None
|
||||
|
||||
# Create an agent with knowledge
|
||||
agent = Agent(
|
||||
role="AI Researcher",
|
||||
@@ -45,14 +47,14 @@ def test_knowledge_included_in_planning(mock_chroma):
|
||||
StringKnowledgeSource(
|
||||
content="AI systems require careful training and validation."
|
||||
)
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
# Create a task for the agent
|
||||
task = Task(
|
||||
description="Explain the basics of AI systems",
|
||||
expected_output="A clear explanation of AI fundamentals",
|
||||
agent=agent
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Create a crew planner
|
||||
@@ -62,23 +64,29 @@ def test_knowledge_included_in_planning(mock_chroma):
|
||||
task_summary = planner._create_tasks_summary()
|
||||
|
||||
# Verify that knowledge is included in planning when present
|
||||
assert "AI systems require careful training" in task_summary, \
|
||||
assert "AI systems require careful training" in task_summary, (
|
||||
"Knowledge content should be present in task summary when knowledge exists"
|
||||
assert '"agent_knowledge"' in task_summary, \
|
||||
)
|
||||
assert '"agent_knowledge"' in task_summary, (
|
||||
"agent_knowledge field should be present in task summary when knowledge exists"
|
||||
)
|
||||
|
||||
# Verify that knowledge is properly formatted
|
||||
assert isinstance(task.agent.knowledge_sources, list), \
|
||||
assert isinstance(task.agent.knowledge_sources, list), (
|
||||
"Knowledge sources should be stored in a list"
|
||||
assert len(task.agent.knowledge_sources) > 0, \
|
||||
)
|
||||
assert len(task.agent.knowledge_sources) > 0, (
|
||||
"At least one knowledge source should be present"
|
||||
assert task.agent.knowledge_sources[0].content in task_summary, \
|
||||
)
|
||||
assert task.agent.knowledge_sources[0].content in task_summary, (
|
||||
"Knowledge source content should be included in task summary"
|
||||
)
|
||||
|
||||
# Verify that other expected components are still present
|
||||
assert task.description in task_summary, \
|
||||
assert task.description in task_summary, (
|
||||
"Task description should be present in task summary"
|
||||
assert task.expected_output in task_summary, \
|
||||
)
|
||||
assert task.expected_output in task_summary, (
|
||||
"Expected output should be present in task summary"
|
||||
assert agent.role in task_summary, \
|
||||
"Agent role should be present in task summary"
|
||||
)
|
||||
assert agent.role in task_summary, "Agent role should be present in task summary"
|
||||
|
||||
Reference in New Issue
Block a user