diff --git a/src/crewai/agent.py b/src/crewai/agent.py
index 999d1d800..a3aeb5f13 100644
--- a/src/crewai/agent.py
+++ b/src/crewai/agent.py
@@ -10,6 +10,7 @@ from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.crew_agent_executor import CrewAgentExecutor
 from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS
 from crewai.knowledge.knowledge import Knowledge
+from crewai.knowledge.knowledge_config import KnowledgeConfig
 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
 from crewai.llm import LLM
@@ -131,6 +132,10 @@ class Agent(BaseAgent):
         default=None,
         description="Knowledge sources for the agent.",
     )
+    knowledge_config: Optional[KnowledgeConfig] = Field(
+        default=None,
+        description="Configuration for knowledge querying (results_limit, score_threshold, metadata_filter).",
+    )
     _knowledge: Optional[Knowledge] = PrivateAttr(
         default=None,
     )
@@ -306,7 +311,20 @@ class Agent(BaseAgent):
                 task_prompt += self.i18n.slice("memory").format(memory=memory)
 
+        # Translate the optional KnowledgeConfig into query keyword arguments
+        # once so the agent-level and crew-level lookups below cannot drift
+        # apart; an empty dict leaves each query on its own defaults.
+        query_kwargs: Dict[str, Any] = {}
+        kc = self.knowledge_config
+        if kc is not None:
+            query_kwargs["limit"] = kc.results_limit
+            query_kwargs["score_threshold"] = kc.score_threshold
+            if kc.metadata_filter:
+                query_kwargs["filter"] = kc.metadata_filter
+
         if self._knowledge:
-            agent_knowledge_snippets = self._knowledge.query([task.prompt()])
+            agent_knowledge_snippets = self._knowledge.query(
+                [task.prompt()], **query_kwargs
+            )
             if agent_knowledge_snippets:
                 agent_knowledge_context = extract_knowledge_context(
                     agent_knowledge_snippets
@@ -315,7 +333,9 @@ class Agent(BaseAgent):
                     task_prompt += agent_knowledge_context
 
         if self.crew:
-            knowledge_snippets = self.crew.query_knowledge([task.prompt()])
+            knowledge_snippets = self.crew.query_knowledge(
+                [task.prompt()], **query_kwargs
+            )
             if knowledge_snippets:
                 crew_knowledge_context = extract_knowledge_context(knowledge_snippets)
                 if crew_knowledge_context:
diff --git a/src/crewai/crew.py b/src/crewai/crew.py
index d488783ea..26f636bcd 100644
--- a/src/crewai/crew.py
+++ b/src/crewai/crew.py
@@ -978,9 +978,32 @@ class Crew(BaseModel):
         result = self._execute_tasks(self.tasks, start_index, True)
         return result
 
-    def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]:
+    def query_knowledge(
+        self,
+        query: List[str],
+        limit: int = 3,
+        filter: Optional[Dict[str, Any]] = None,
+        score_threshold: float = 0.35,
+    ) -> Union[List[Dict[str, Any]], None]:
+        """Query the crew-level knowledge, if the crew has any.
+
+        Args:
+            query: Query strings forwarded to the knowledge store.
+            limit: Maximum number of results to request.
+            filter: Optional metadata filter, forwarded unchanged.
+            score_threshold: Score threshold, forwarded unchanged.
+
+        Returns:
+            The result of ``self._knowledge.query``, or ``None`` when no
+            crew-level knowledge is set.
+        """
         if self._knowledge:
-            return self._knowledge.query(query)
+            return self._knowledge.query(
+                query,
+                limit=limit,
+                filter=filter,
+                score_threshold=score_threshold,
+            )
         return None
 
     def copy(self):
diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py
index f9f55a517..e9075e9d3 100644
--- a/src/crewai/knowledge/knowledge.py
+++ b/src/crewai/knowledge/knowledge.py
@@ -45,7 +45,13 @@ class Knowledge(BaseModel):
             source.storage = self.storage
             source.add()
 
-    def query(self, query: List[str], limit: int = 3) -> List[Dict[str, Any]]:
+    def query(
+        self,
+        query: List[str],
+        limit: int = 3,
+        filter: Optional[Dict[str, Any]] = None,
+        score_threshold: float = 0.35,
+    ) -> List[Dict[str, Any]]:
         """
         Query across all knowledge sources to find the most relevant information.
         Returns the top_k most relevant chunks.
@@ -54,6 +60,8 @@ class Knowledge(BaseModel):
         results = self.storage.search(
             query,
             limit,
+            filter=filter,
+            score_threshold=score_threshold,
         )
         return results
 
diff --git a/src/crewai/knowledge/knowledge_config.py b/src/crewai/knowledge/knowledge_config.py
new file mode 100644
index 000000000..53e920746
--- /dev/null
+++ b/src/crewai/knowledge/knowledge_config.py
@@ -0,0 +1,28 @@
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field
+
+
+class KnowledgeConfig(BaseModel):
+    """Configuration for knowledge querying behavior.
+
+    Attributes:
+        results_limit: Maximum number of results to return from a knowledge
+            query; must be at least 1.
+        score_threshold: Minimum relevance score for results.
+        metadata_filter: Metadata filter dict passed to the vector store query.
+    """
+
+    results_limit: int = Field(
+        default=3,
+        ge=1,
+        description="Maximum number of results to return from a knowledge query.",
+    )
+    score_threshold: float = Field(
+        default=0.35,
+        description="Minimum relevance score for results.",
+    )
+    metadata_filter: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description="Metadata filter dict passed to the vector store query.",
+    )
diff --git a/src/crewai/knowledge/source/base_knowledge_source.py b/src/crewai/knowledge/source/base_knowledge_source.py
index 88c3ab360..f17b740cd 100644
--- a/src/crewai/knowledge/source/base_knowledge_source.py
+++ b/src/crewai/knowledge/source/base_knowledge_source.py
@@ -46,4 +46,4 @@ class BaseKnowledgeSource(BaseModel, ABC):
         Save the documents to the storage.
         This method should be called after the chunks and embeddings are generated.
         """
-        self.storage.save(self.chunks)
+        self.storage.save(self.chunks, metadata=self.metadata or None)
diff --git a/tests/knowledge/knowledge_test.py b/tests/knowledge/knowledge_test.py
index 366067587..ffd773a4d 100644
--- a/tests/knowledge/knowledge_test.py
+++ b/tests/knowledge/knowledge_test.py
@@ -1,11 +1,13 @@
 """Test Knowledge creation and querying functionality."""
 
 from pathlib import Path
-from typing import List, Union
-from unittest.mock import patch
+from typing import Any, Dict, List, Optional, Union
+from unittest.mock import MagicMock, call, patch
 
 import pytest
 
+from crewai.knowledge.knowledge import Knowledge
+from crewai.knowledge.knowledge_config import KnowledgeConfig
 from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
 from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
 from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource
@@ -584,3 +586,171 @@ def test_docling_source_with_local_file():
     docling_source = CrewDoclingSource(file_paths=[pdf_path])
     assert docling_source.file_paths == [pdf_path]
     assert docling_source.content is not None
+
+
+# --- Tests for KnowledgeConfig and metadata_filter support ---
+
+
+class TestKnowledgeConfig:
+    """Tests for the KnowledgeConfig model."""
+
+    def test_default_values(self):
+        config = KnowledgeConfig()
+        assert config.results_limit == 3
+        assert config.score_threshold == 0.35
+        assert config.metadata_filter is None
+
+    def test_custom_values(self):
+        config = KnowledgeConfig(
+            results_limit=10,
+            score_threshold=0.5,
+            metadata_filter={"task": "translation"},
+        )
+        assert config.results_limit == 10
+        assert config.score_threshold == 0.5
+        assert config.metadata_filter == {"task": "translation"}
+
+    def test_partial_override(self):
+        config = KnowledgeConfig(metadata_filter={"source": "docs"})
+        assert config.results_limit == 3
+        assert config.score_threshold == 0.35
+        assert config.metadata_filter == {"source": "docs"}
+
+
+class TestKnowledgeQueryForwardsParams:
+    """Knowledge.query() forwards limit, filter and score_threshold to storage."""
+
+    def test_query_forwards_metadata_filter(self):
+        mock_storage = MagicMock()
+        mock_storage.search.return_value = []
+        mock_storage.initialize_knowledge_storage.return_value = None
+
+        with patch(
+            "crewai.knowledge.knowledge.KnowledgeStorage",
+            return_value=mock_storage,
+        ):
+            knowledge = Knowledge(
+                collection_name="test",
+                sources=[],
+                storage=mock_storage,
+            )
+
+        knowledge.query(
+            ["test query"],
+            limit=5,
+            filter={"task": "translation"},
+            score_threshold=0.6,
+        )
+
+        mock_storage.search.assert_called_once_with(
+            ["test query"],
+            5,
+            filter={"task": "translation"},
+            score_threshold=0.6,
+        )
+
+    def test_query_defaults_without_filter(self):
+        mock_storage = MagicMock()
+        mock_storage.search.return_value = []
+        mock_storage.initialize_knowledge_storage.return_value = None
+
+        with patch(
+            "crewai.knowledge.knowledge.KnowledgeStorage",
+            return_value=mock_storage,
+        ):
+            knowledge = Knowledge(
+                collection_name="test",
+                sources=[],
+                storage=mock_storage,
+            )
+
+        knowledge.query(["test query"])
+
+        mock_storage.search.assert_called_once_with(
+            ["test query"],
+            3,
+            filter=None,
+            score_threshold=0.35,
+        )
+
+
+class TestBaseKnowledgeSourceSavesMetadata:
+    """_save_documents() hands self.metadata (None when unset) to storage.save()."""
+
+    def test_save_documents_passes_metadata(self):
+        source = StringKnowledgeSource(
+            content="Hello world",
+            metadata={"task": "translation"},
+        )
+        source.storage = MagicMock()
+        source.chunks = ["Hello world"]
+        source._save_documents()
+
+        source.storage.save.assert_called_once_with(
+            ["Hello world"],
+            metadata={"task": "translation"},
+        )
+
+    def test_save_documents_no_metadata(self):
+        source = StringKnowledgeSource(content="Hello world")
+        source.storage = MagicMock()
+        source.chunks = ["Hello world"]
+        source._save_documents()
+
+        source.storage.save.assert_called_once_with(
+            ["Hello world"],
+            metadata=None,
+        )
+
+
+class TestCrewQueryKnowledgeForwardsParams:
+    """Crew.query_knowledge() forwards limit, filter and score_threshold."""
+
+    def _make_crew_stub(self, knowledge=None):
+        """Return a MagicMock with the real Crew.query_knowledge bound to it."""
+        stub = MagicMock()
+        stub._knowledge = knowledge
+        # Bind the real method so production code runs against the stub
+        from crewai.crew import Crew
+
+        stub.query_knowledge = Crew.query_knowledge.__get__(stub, type(stub))
+        return stub
+
+    def test_crew_query_knowledge_with_filter(self):
+        mock_knowledge = MagicMock()
+        mock_knowledge.query.return_value = [{"context": "test", "score": 0.9}]
+        crew = self._make_crew_stub(knowledge=mock_knowledge)
+
+        crew.query_knowledge(
+            ["test query"],
+            limit=5,
+            filter={"task": "translation"},
+            score_threshold=0.6,
+        )
+
+        mock_knowledge.query.assert_called_once_with(
+            ["test query"],
+            limit=5,
+            filter={"task": "translation"},
+            score_threshold=0.6,
+        )
+
+    def test_crew_query_knowledge_defaults(self):
+        mock_knowledge = MagicMock()
+        mock_knowledge.query.return_value = [{"context": "test", "score": 0.9}]
+        crew = self._make_crew_stub(knowledge=mock_knowledge)
+
+        crew.query_knowledge(["test query"])
+
+        mock_knowledge.query.assert_called_once_with(
+            ["test query"],
+            limit=3,
+            filter=None,
+            score_threshold=0.35,
+        )
+
+    def test_crew_query_knowledge_no_knowledge(self):
+        crew = self._make_crew_stub(knowledge=None)
+
+        result = crew.query_knowledge(["test query"])
+        assert result is None