Enhance knowledge management in CrewAI (#2637)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled

* Enhance knowledge management in CrewAI

- Added `KnowledgeConfig` class to configure knowledge retrieval parameters such as `limit` and `score_threshold`.
- Updated `Agent` and `Crew` classes to utilize the new knowledge configuration for querying knowledge sources.
- Enhanced documentation to clarify the addition of knowledge sources at both agent and crew levels.
- Introduced new tips in documentation to guide users on knowledge source management and configuration.

* Refactor knowledge configuration parameters in CrewAI

- Renamed `limit` to `results_limit` in `KnowledgeConfig`, `query_knowledge`, and `query` methods for consistency and clarity.
- Updated related documentation to reflect the new parameter name, ensuring users understand the configuration options for knowledge retrieval.

* Refactor agent tests to utilize mock knowledge storage

- Updated test cases in `agent_test.py` to use `KnowledgeStorage` for mocking knowledge sources, enhancing test reliability and clarity.
- Renamed `limit` to `results_limit` in `KnowledgeConfig` for consistency with recent changes.
- Ensured that knowledge queries are properly mocked to return expected results during tests.

* Add VCR support for agent tests with query limits and score thresholds

- Introduced `@pytest.mark.vcr` decorator in `agent_test.py` for tests involving knowledge sources, ensuring consistent recording of HTTP interactions.
- Added new YAML cassette files for `test_agent_with_knowledge_sources_with_query_limit_and_score_threshold` and `test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_default`, capturing the expected API responses for these tests.
- Enhanced test reliability by utilizing VCR to manage external API calls during testing.

* Update documentation to format parameter names in code style

- Changed the formatting of `results_limit` and `score_threshold` in the documentation to use code style for better clarity and emphasis.
- Ensured consistency in documentation presentation to enhance user understanding of configuration options.

* Enhance KnowledgeConfig with field descriptions

- Updated `results_limit` and `score_threshold` in `KnowledgeConfig` to use Pydantic's `Field` for improved documentation and clarity.
- Added descriptions to both parameters to provide better context for their usage in knowledge retrieval configuration.

* docstrings added
This commit is contained in:
Lorenze Jay
2025-04-18 18:33:04 -07:00
committed by GitHub
parent 371f19f3cd
commit 311a078ca6
10 changed files with 836 additions and 22 deletions

View File

@@ -114,6 +114,14 @@ class Agent(BaseAgent):
default=None,
description="Embedder configuration for the agent.",
)
agent_knowledge_context: Optional[str] = Field(
default=None,
description="Knowledge context for the agent.",
)
crew_knowledge_context: Optional[str] = Field(
default=None,
description="Knowledge context for the crew.",
)
@model_validator(mode="after")
def post_init_setup(self):
@@ -234,22 +242,30 @@ class Agent(BaseAgent):
memory = contextual_memory.build_context_for_task(task, context)
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
knowledge_config = (
self.knowledge_config.model_dump() if self.knowledge_config else {}
)
if self.knowledge:
agent_knowledge_snippets = self.knowledge.query([task.prompt()])
agent_knowledge_snippets = self.knowledge.query(
[task.prompt()], **knowledge_config
)
if agent_knowledge_snippets:
agent_knowledge_context = extract_knowledge_context(
self.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if agent_knowledge_context:
task_prompt += agent_knowledge_context
if self.agent_knowledge_context:
task_prompt += self.agent_knowledge_context
if self.crew:
knowledge_snippets = self.crew.query_knowledge([task.prompt()])
knowledge_snippets = self.crew.query_knowledge(
[task.prompt()], **knowledge_config
)
if knowledge_snippets:
crew_knowledge_context = extract_knowledge_context(knowledge_snippets)
if crew_knowledge_context:
task_prompt += crew_knowledge_context
self.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if self.crew_knowledge_context:
task_prompt += self.crew_knowledge_context
tools = tools or self.tools or []
self.create_agent_executor(tools=tools, task=task)

View File

@@ -19,6 +19,7 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
@@ -155,6 +156,10 @@ class BaseAgent(ABC, BaseModel):
adapted_agent: bool = Field(
default=False, description="Whether the agent is adapted"
)
knowledge_config: Optional[KnowledgeConfig] = Field(
default=None,
description="Knowledge configuration for the agent such as limits and threshold",
)
@model_validator(mode="before")
@classmethod

View File

@@ -304,9 +304,7 @@ class Crew(BaseModel):
"""Initialize private memory attributes."""
self._external_memory = (
# External memory doesnt support a default value since it was designed to be managed entirely externally
self.external_memory.set_crew(self)
if self.external_memory
else None
self.external_memory.set_crew(self) if self.external_memory else None
)
self._long_term_memory = self.long_term_memory
@@ -1136,9 +1134,13 @@ class Crew(BaseModel):
result = self._execute_tasks(self.tasks, start_index, True)
return result
def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]:
def query_knowledge(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> Union[List[Dict[str, Any]], None]:
if self.knowledge:
return self.knowledge.query(query)
return self.knowledge.query(
query, results_limit=results_limit, score_threshold=score_threshold
)
return None
def fetch_inputs(self) -> Set[str]:
@@ -1220,9 +1222,13 @@ class Crew(BaseModel):
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
if self.short_term_memory:
copied_data["short_term_memory"] = self.short_term_memory.model_copy(deep=True)
copied_data["short_term_memory"] = self.short_term_memory.model_copy(
deep=True
)
if self.long_term_memory:
copied_data["long_term_memory"] = self.long_term_memory.model_copy(deep=True)
copied_data["long_term_memory"] = self.long_term_memory.model_copy(
deep=True
)
if self.entity_memory:
copied_data["entity_memory"] = self.entity_memory.model_copy(deep=True)
if self.external_memory:
@@ -1230,7 +1236,6 @@ class Crew(BaseModel):
if self.user_memory:
copied_data["user_memory"] = self.user_memory.model_copy(deep=True)
copied_data.pop("agents", None)
copied_data.pop("tasks", None)
@@ -1403,7 +1408,10 @@ class Crew(BaseModel):
"short": (getattr(self, "_short_term_memory", None), "short term"),
"entity": (getattr(self, "_entity_memory", None), "entity"),
"knowledge": (getattr(self, "knowledge", None), "knowledge"),
"kickoff_outputs": (getattr(self, "_task_output_handler", None), "task output"),
"kickoff_outputs": (
getattr(self, "_task_output_handler", None),
"task output",
),
"external": (getattr(self, "_external_memory", None), "external"),
}

View File

@@ -43,7 +43,9 @@ class Knowledge(BaseModel):
self.storage.initialize_knowledge_storage()
self._add_sources()
def query(self, query: List[str], limit: int = 3) -> List[Dict[str, Any]]:
def query(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> List[Dict[str, Any]]:
"""
Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.
@@ -56,7 +58,8 @@ class Knowledge(BaseModel):
results = self.storage.search(
query,
limit,
limit=results_limit,
score_threshold=score_threshold,
)
return results

View File

@@ -0,0 +1,16 @@
from pydantic import BaseModel, Field
class KnowledgeConfig(BaseModel):
"""Configuration for knowledge retrieval.
Args:
results_limit (int): The number of relevant documents to return.
score_threshold (float): The minimum score for a document to be considered relevant.
"""
results_limit: int = Field(default=3, description="The number of results to return")
score_threshold: float = Field(
default=0.35,
description="The minimum score for a result to be considered relevant",
)

View File

@@ -4,7 +4,7 @@ import io
import logging
import os
import shutil
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Union
import chromadb
import chromadb.errors