Compare commits

...

1 Commits

Author SHA1 Message Date
Devin AI
4f72839fea Fix #5805: Add metadata_filter support to Knowledge querying pipeline
- Add KnowledgeConfig class with results_limit, score_threshold, metadata_filter
- Add knowledge_config field to Agent
- Update Knowledge.query() to forward filter and score_threshold to storage
- Update Crew.query_knowledge() to accept and forward filter params
- Fix BaseKnowledgeSource._save_documents() to pass self.metadata to storage
- Wire Agent.execute_task() to use knowledge_config for both agent and crew queries
- Add 10 tests covering all changes

Co-Authored-By: João <joao@crewai.com>
2026-05-14 08:16:13 +00:00
6 changed files with 246 additions and 8 deletions

View File

@@ -10,6 +10,7 @@ from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.llm import LLM
@@ -131,6 +132,10 @@ class Agent(BaseAgent):
default=None,
description="Knowledge sources for the agent.",
)
knowledge_config: Optional[KnowledgeConfig] = Field(
default=None,
description="Configuration for knowledge querying (results_limit, score_threshold, metadata_filter).",
)
_knowledge: Optional[Knowledge] = PrivateAttr(
default=None,
)
@@ -306,7 +311,16 @@ class Agent(BaseAgent):
task_prompt += self.i18n.slice("memory").format(memory=memory)
if self._knowledge:
agent_knowledge_snippets = self._knowledge.query([task.prompt()])
kc = self.knowledge_config
query_kwargs: Dict[str, Any] = {}
if kc:
query_kwargs["limit"] = kc.results_limit
query_kwargs["score_threshold"] = kc.score_threshold
if kc.metadata_filter:
query_kwargs["filter"] = kc.metadata_filter
agent_knowledge_snippets = self._knowledge.query(
[task.prompt()], **query_kwargs
)
if agent_knowledge_snippets:
agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
@@ -315,7 +329,16 @@ class Agent(BaseAgent):
task_prompt += agent_knowledge_context
if self.crew:
knowledge_snippets = self.crew.query_knowledge([task.prompt()])
kc = self.knowledge_config
query_kwargs = {}
if kc:
query_kwargs["limit"] = kc.results_limit
query_kwargs["score_threshold"] = kc.score_threshold
if kc.metadata_filter:
query_kwargs["filter"] = kc.metadata_filter
knowledge_snippets = self.crew.query_knowledge(
[task.prompt()], **query_kwargs
)
if knowledge_snippets:
crew_knowledge_context = extract_knowledge_context(knowledge_snippets)
if crew_knowledge_context:

View File

@@ -978,9 +978,20 @@ class Crew(BaseModel):
result = self._execute_tasks(self.tasks, start_index, True)
return result
def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]:
def query_knowledge(
self,
query: List[str],
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> Union[List[Dict[str, Any]], None]:
if self._knowledge:
return self._knowledge.query(query)
return self._knowledge.query(
query,
limit=limit,
filter=filter,
score_threshold=score_threshold,
)
return None
def copy(self):

View File

@@ -45,7 +45,13 @@ class Knowledge(BaseModel):
source.storage = self.storage
source.add()
def query(self, query: List[str], limit: int = 3) -> List[Dict[str, Any]]:
def query(
self,
query: List[str],
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
"""
Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.
@@ -54,6 +60,8 @@ class Knowledge(BaseModel):
results = self.storage.search(
query,
limit,
filter=filter,
score_threshold=score_threshold,
)
return results

View File

@@ -0,0 +1,26 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
class KnowledgeConfig(BaseModel):
"""Configuration for knowledge querying behavior.
Attributes:
results_limit: Maximum number of results to return from a knowledge query.
score_threshold: Minimum relevance score for results.
metadata_filter: Metadata filter dict passed to the vector store query.
"""
results_limit: int = Field(
default=3,
description="Maximum number of results to return from a knowledge query.",
)
score_threshold: float = Field(
default=0.35,
description="Minimum relevance score for results.",
)
metadata_filter: Optional[Dict[str, Any]] = Field(
default=None,
description="Metadata filter dict passed to the vector store query.",
)

View File

@@ -46,4 +46,4 @@ class BaseKnowledgeSource(BaseModel, ABC):
Save the documents to the storage.
This method should be called after the chunks and embeddings are generated.
"""
self.storage.save(self.chunks)
self.storage.save(self.chunks, metadata=self.metadata if self.metadata else None)

View File

@@ -1,11 +1,13 @@
"""Test Knowledge creation and querying functionality."""
from pathlib import Path
from typing import List, Union
from unittest.mock import patch
from typing import Any, Dict, List, Optional, Union
from unittest.mock import MagicMock, call, patch
import pytest
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource
@@ -584,3 +586,171 @@ def test_docling_source_with_local_file():
docling_source = CrewDoclingSource(file_paths=[pdf_path])
assert docling_source.file_paths == [pdf_path]
assert docling_source.content is not None
# --- Tests for KnowledgeConfig and metadata_filter support ---
class TestKnowledgeConfig:
"""Tests for the KnowledgeConfig model."""
def test_default_values(self):
config = KnowledgeConfig()
assert config.results_limit == 3
assert config.score_threshold == 0.35
assert config.metadata_filter is None
def test_custom_values(self):
config = KnowledgeConfig(
results_limit=10,
score_threshold=0.5,
metadata_filter={"task": "translation"},
)
assert config.results_limit == 10
assert config.score_threshold == 0.5
assert config.metadata_filter == {"task": "translation"}
def test_partial_override(self):
config = KnowledgeConfig(metadata_filter={"source": "docs"})
assert config.results_limit == 3
assert config.score_threshold == 0.35
assert config.metadata_filter == {"source": "docs"}
class TestKnowledgeQueryForwardsParams:
"""Tests that Knowledge.query() forwards filter and score_threshold to storage."""
def test_query_forwards_metadata_filter(self):
mock_storage = MagicMock()
mock_storage.search.return_value = []
mock_storage.initialize_knowledge_storage.return_value = None
with patch(
"crewai.knowledge.knowledge.KnowledgeStorage",
return_value=mock_storage,
):
knowledge = Knowledge(
collection_name="test",
sources=[],
storage=mock_storage,
)
knowledge.query(
["test query"],
limit=5,
filter={"task": "translation"},
score_threshold=0.6,
)
mock_storage.search.assert_called_once_with(
["test query"],
5,
filter={"task": "translation"},
score_threshold=0.6,
)
def test_query_defaults_without_filter(self):
mock_storage = MagicMock()
mock_storage.search.return_value = []
mock_storage.initialize_knowledge_storage.return_value = None
with patch(
"crewai.knowledge.knowledge.KnowledgeStorage",
return_value=mock_storage,
):
knowledge = Knowledge(
collection_name="test",
sources=[],
storage=mock_storage,
)
knowledge.query(["test query"])
mock_storage.search.assert_called_once_with(
["test query"],
3,
filter=None,
score_threshold=0.35,
)
class TestBaseKnowledgeSourceSavesMetadata:
"""Tests that _save_documents passes self.metadata to storage.save()."""
def test_save_documents_passes_metadata(self):
source = StringKnowledgeSource(
content="Hello world",
metadata={"task": "translation"},
)
source.storage = MagicMock()
source.chunks = ["Hello world"]
source._save_documents()
source.storage.save.assert_called_once_with(
["Hello world"],
metadata={"task": "translation"},
)
def test_save_documents_no_metadata(self):
source = StringKnowledgeSource(content="Hello world")
source.storage = MagicMock()
source.chunks = ["Hello world"]
source._save_documents()
source.storage.save.assert_called_once_with(
["Hello world"],
metadata=None,
)
class TestCrewQueryKnowledgeForwardsParams:
"""Tests that Crew.query_knowledge() forwards filter and score_threshold."""
def _make_crew_stub(self, knowledge=None):
"""Create a minimal stub that satisfies Crew.query_knowledge."""
stub = MagicMock()
stub._knowledge = knowledge
# Bind the real method to the stub
from crewai.crew import Crew
stub.query_knowledge = Crew.query_knowledge.__get__(stub, type(stub))
return stub
def test_crew_query_knowledge_with_filter(self):
mock_knowledge = MagicMock()
mock_knowledge.query.return_value = [{"context": "test", "score": 0.9}]
crew = self._make_crew_stub(knowledge=mock_knowledge)
crew.query_knowledge(
["test query"],
limit=5,
filter={"task": "translation"},
score_threshold=0.6,
)
mock_knowledge.query.assert_called_once_with(
["test query"],
limit=5,
filter={"task": "translation"},
score_threshold=0.6,
)
def test_crew_query_knowledge_defaults(self):
mock_knowledge = MagicMock()
mock_knowledge.query.return_value = [{"context": "test", "score": 0.9}]
crew = self._make_crew_stub(knowledge=mock_knowledge)
crew.query_knowledge(["test query"])
mock_knowledge.query.assert_called_once_with(
["test query"],
limit=3,
filter=None,
score_threshold=0.35,
)
def test_crew_query_knowledge_no_knowledge(self):
crew = self._make_crew_stub(knowledge=None)
result = crew.query_knowledge(["test query"])
assert result is None