mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-14 13:38:12 +00:00
Compare commits
1 Commits
bugfix/asy
...
devin/1778
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f72839fea |
@@ -10,6 +10,7 @@ from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.llm import LLM
|
||||
@@ -131,6 +132,10 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
knowledge_config: Optional[KnowledgeConfig] = Field(
|
||||
default=None,
|
||||
description="Configuration for knowledge querying (results_limit, score_threshold, metadata_filter).",
|
||||
)
|
||||
_knowledge: Optional[Knowledge] = PrivateAttr(
|
||||
default=None,
|
||||
)
|
||||
@@ -306,7 +311,16 @@ class Agent(BaseAgent):
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
if self._knowledge:
|
||||
agent_knowledge_snippets = self._knowledge.query([task.prompt()])
|
||||
kc = self.knowledge_config
|
||||
query_kwargs: Dict[str, Any] = {}
|
||||
if kc:
|
||||
query_kwargs["limit"] = kc.results_limit
|
||||
query_kwargs["score_threshold"] = kc.score_threshold
|
||||
if kc.metadata_filter:
|
||||
query_kwargs["filter"] = kc.metadata_filter
|
||||
agent_knowledge_snippets = self._knowledge.query(
|
||||
[task.prompt()], **query_kwargs
|
||||
)
|
||||
if agent_knowledge_snippets:
|
||||
agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
@@ -315,7 +329,16 @@ class Agent(BaseAgent):
|
||||
task_prompt += agent_knowledge_context
|
||||
|
||||
if self.crew:
|
||||
knowledge_snippets = self.crew.query_knowledge([task.prompt()])
|
||||
kc = self.knowledge_config
|
||||
query_kwargs = {}
|
||||
if kc:
|
||||
query_kwargs["limit"] = kc.results_limit
|
||||
query_kwargs["score_threshold"] = kc.score_threshold
|
||||
if kc.metadata_filter:
|
||||
query_kwargs["filter"] = kc.metadata_filter
|
||||
knowledge_snippets = self.crew.query_knowledge(
|
||||
[task.prompt()], **query_kwargs
|
||||
)
|
||||
if knowledge_snippets:
|
||||
crew_knowledge_context = extract_knowledge_context(knowledge_snippets)
|
||||
if crew_knowledge_context:
|
||||
|
||||
@@ -978,9 +978,20 @@ class Crew(BaseModel):
|
||||
result = self._execute_tasks(self.tasks, start_index, True)
|
||||
return result
|
||||
|
||||
def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]:
|
||||
def query_knowledge(
|
||||
self,
|
||||
query: List[str],
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> Union[List[Dict[str, Any]], None]:
|
||||
if self._knowledge:
|
||||
return self._knowledge.query(query)
|
||||
return self._knowledge.query(
|
||||
query,
|
||||
limit=limit,
|
||||
filter=filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
return None
|
||||
|
||||
def copy(self):
|
||||
|
||||
@@ -45,7 +45,13 @@ class Knowledge(BaseModel):
|
||||
source.storage = self.storage
|
||||
source.add()
|
||||
|
||||
def query(self, query: List[str], limit: int = 3) -> List[Dict[str, Any]]:
|
||||
def query(
|
||||
self,
|
||||
query: List[str],
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query across all knowledge sources to find the most relevant information.
|
||||
Returns the top_k most relevant chunks.
|
||||
@@ -54,6 +60,8 @@ class Knowledge(BaseModel):
|
||||
results = self.storage.search(
|
||||
query,
|
||||
limit,
|
||||
filter=filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
26
src/crewai/knowledge/knowledge_config.py
Normal file
26
src/crewai/knowledge/knowledge_config.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class KnowledgeConfig(BaseModel):
|
||||
"""Configuration for knowledge querying behavior.
|
||||
|
||||
Attributes:
|
||||
results_limit: Maximum number of results to return from a knowledge query.
|
||||
score_threshold: Minimum relevance score for results.
|
||||
metadata_filter: Metadata filter dict passed to the vector store query.
|
||||
"""
|
||||
|
||||
results_limit: int = Field(
|
||||
default=3,
|
||||
description="Maximum number of results to return from a knowledge query.",
|
||||
)
|
||||
score_threshold: float = Field(
|
||||
default=0.35,
|
||||
description="Minimum relevance score for results.",
|
||||
)
|
||||
metadata_filter: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Metadata filter dict passed to the vector store query.",
|
||||
)
|
||||
@@ -46,4 +46,4 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
Save the documents to the storage.
|
||||
This method should be called after the chunks and embeddings are generated.
|
||||
"""
|
||||
self.storage.save(self.chunks)
|
||||
self.storage.save(self.chunks, metadata=self.metadata if self.metadata else None)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
"""Test Knowledge creation and querying functionality."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
from unittest.mock import patch
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||
from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
|
||||
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
|
||||
from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource
|
||||
@@ -584,3 +586,171 @@ def test_docling_source_with_local_file():
|
||||
docling_source = CrewDoclingSource(file_paths=[pdf_path])
|
||||
assert docling_source.file_paths == [pdf_path]
|
||||
assert docling_source.content is not None
|
||||
|
||||
|
||||
# --- Tests for KnowledgeConfig and metadata_filter support ---
|
||||
|
||||
|
||||
class TestKnowledgeConfig:
|
||||
"""Tests for the KnowledgeConfig model."""
|
||||
|
||||
def test_default_values(self):
|
||||
config = KnowledgeConfig()
|
||||
assert config.results_limit == 3
|
||||
assert config.score_threshold == 0.35
|
||||
assert config.metadata_filter is None
|
||||
|
||||
def test_custom_values(self):
|
||||
config = KnowledgeConfig(
|
||||
results_limit=10,
|
||||
score_threshold=0.5,
|
||||
metadata_filter={"task": "translation"},
|
||||
)
|
||||
assert config.results_limit == 10
|
||||
assert config.score_threshold == 0.5
|
||||
assert config.metadata_filter == {"task": "translation"}
|
||||
|
||||
def test_partial_override(self):
|
||||
config = KnowledgeConfig(metadata_filter={"source": "docs"})
|
||||
assert config.results_limit == 3
|
||||
assert config.score_threshold == 0.35
|
||||
assert config.metadata_filter == {"source": "docs"}
|
||||
|
||||
|
||||
class TestKnowledgeQueryForwardsParams:
|
||||
"""Tests that Knowledge.query() forwards filter and score_threshold to storage."""
|
||||
|
||||
def test_query_forwards_metadata_filter(self):
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.search.return_value = []
|
||||
mock_storage.initialize_knowledge_storage.return_value = None
|
||||
|
||||
with patch(
|
||||
"crewai.knowledge.knowledge.KnowledgeStorage",
|
||||
return_value=mock_storage,
|
||||
):
|
||||
knowledge = Knowledge(
|
||||
collection_name="test",
|
||||
sources=[],
|
||||
storage=mock_storage,
|
||||
)
|
||||
|
||||
knowledge.query(
|
||||
["test query"],
|
||||
limit=5,
|
||||
filter={"task": "translation"},
|
||||
score_threshold=0.6,
|
||||
)
|
||||
|
||||
mock_storage.search.assert_called_once_with(
|
||||
["test query"],
|
||||
5,
|
||||
filter={"task": "translation"},
|
||||
score_threshold=0.6,
|
||||
)
|
||||
|
||||
def test_query_defaults_without_filter(self):
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.search.return_value = []
|
||||
mock_storage.initialize_knowledge_storage.return_value = None
|
||||
|
||||
with patch(
|
||||
"crewai.knowledge.knowledge.KnowledgeStorage",
|
||||
return_value=mock_storage,
|
||||
):
|
||||
knowledge = Knowledge(
|
||||
collection_name="test",
|
||||
sources=[],
|
||||
storage=mock_storage,
|
||||
)
|
||||
|
||||
knowledge.query(["test query"])
|
||||
|
||||
mock_storage.search.assert_called_once_with(
|
||||
["test query"],
|
||||
3,
|
||||
filter=None,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
|
||||
|
||||
class TestBaseKnowledgeSourceSavesMetadata:
|
||||
"""Tests that _save_documents passes self.metadata to storage.save()."""
|
||||
|
||||
def test_save_documents_passes_metadata(self):
|
||||
source = StringKnowledgeSource(
|
||||
content="Hello world",
|
||||
metadata={"task": "translation"},
|
||||
)
|
||||
source.storage = MagicMock()
|
||||
source.chunks = ["Hello world"]
|
||||
source._save_documents()
|
||||
|
||||
source.storage.save.assert_called_once_with(
|
||||
["Hello world"],
|
||||
metadata={"task": "translation"},
|
||||
)
|
||||
|
||||
def test_save_documents_no_metadata(self):
|
||||
source = StringKnowledgeSource(content="Hello world")
|
||||
source.storage = MagicMock()
|
||||
source.chunks = ["Hello world"]
|
||||
source._save_documents()
|
||||
|
||||
source.storage.save.assert_called_once_with(
|
||||
["Hello world"],
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
|
||||
class TestCrewQueryKnowledgeForwardsParams:
|
||||
"""Tests that Crew.query_knowledge() forwards filter and score_threshold."""
|
||||
|
||||
def _make_crew_stub(self, knowledge=None):
|
||||
"""Create a minimal stub that satisfies Crew.query_knowledge."""
|
||||
stub = MagicMock()
|
||||
stub._knowledge = knowledge
|
||||
# Bind the real method to the stub
|
||||
from crewai.crew import Crew
|
||||
|
||||
stub.query_knowledge = Crew.query_knowledge.__get__(stub, type(stub))
|
||||
return stub
|
||||
|
||||
def test_crew_query_knowledge_with_filter(self):
|
||||
mock_knowledge = MagicMock()
|
||||
mock_knowledge.query.return_value = [{"context": "test", "score": 0.9}]
|
||||
crew = self._make_crew_stub(knowledge=mock_knowledge)
|
||||
|
||||
crew.query_knowledge(
|
||||
["test query"],
|
||||
limit=5,
|
||||
filter={"task": "translation"},
|
||||
score_threshold=0.6,
|
||||
)
|
||||
|
||||
mock_knowledge.query.assert_called_once_with(
|
||||
["test query"],
|
||||
limit=5,
|
||||
filter={"task": "translation"},
|
||||
score_threshold=0.6,
|
||||
)
|
||||
|
||||
def test_crew_query_knowledge_defaults(self):
|
||||
mock_knowledge = MagicMock()
|
||||
mock_knowledge.query.return_value = [{"context": "test", "score": 0.9}]
|
||||
crew = self._make_crew_stub(knowledge=mock_knowledge)
|
||||
|
||||
crew.query_knowledge(["test query"])
|
||||
|
||||
mock_knowledge.query.assert_called_once_with(
|
||||
["test query"],
|
||||
limit=3,
|
||||
filter=None,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
|
||||
def test_crew_query_knowledge_no_knowledge(self):
|
||||
crew = self._make_crew_stub(knowledge=None)
|
||||
|
||||
result = crew.query_knowledge(["test query"])
|
||||
assert result is None
|
||||
|
||||
Reference in New Issue
Block a user