Enhance knowledge management in CrewAI

- Added `KnowledgeConfig` class to configure knowledge retrieval parameters such as `limit` and `score_threshold`.
- Updated `Agent` and `Crew` classes to utilize the new knowledge configuration for querying knowledge sources.
- Enhanced documentation to clarify the addition of knowledge sources at both agent and crew levels.
- Introduced new tips in documentation to guide users on knowledge source management and configuration.
This commit is contained in:
lorenzejay
2025-04-17 15:45:09 -07:00
parent 870dffbb89
commit d93e08a3a6
8 changed files with 149 additions and 21 deletions

View File

@@ -42,6 +42,16 @@ CrewAI supports various types of knowledge sources out of the box:
| `collection_name` | **str** | No | Name of the collection where the knowledge will be stored. Used to identify different sets of knowledge. Defaults to "knowledge" if not provided. | | `collection_name` | **str** | No | Name of the collection where the knowledge will be stored. Used to identify different sets of knowledge. Defaults to "knowledge" if not provided. |
| `storage` | **Optional[KnowledgeStorage]** | No | Custom storage configuration for managing how the knowledge is stored and retrieved. If not provided, a default storage will be created. | | `storage` | **Optional[KnowledgeStorage]** | No | Custom storage configuration for managing how the knowledge is stored and retrieved. If not provided, a default storage will be created. |
<Tip>
Unlike retrieval from a vector database using a tool, agents preloaded with knowledge will not need a retrieval persona or task.
Simply add the relevant knowledge sources your agent or crew needs to function.
Knowledge sources can be added at the agent or crew level.
Crew level knowledge sources will be used by **all agents** in the crew.
Agent level knowledge sources will be used by the **specific agent** that is preloaded with the knowledge.
</Tip>
## Quickstart Example ## Quickstart Example
<Tip> <Tip>
@@ -146,6 +156,26 @@ result = crew.kickoff(
) )
``` ```
## Knowledge Configuration
You can configure the knowledge configuration for the crew or agent.
```python Code
from crewai.knowledge.knowledge_config import KnowledgeConfig
knowledge_config = KnowledgeConfig(limit=10, score_threshold=0.5)
agent = Agent(
...
knowledge_config=knowledge_config
)
```
<Tip>
limit: is the number of relevant documents to return. Default is 3.
score_threshold: is the minimum score for a document to be considered relevant. Default is 0.35.
</Tip>
## More Examples ## More Examples
Here are examples of how to use different types of knowledge sources: Here are examples of how to use different types of knowledge sources:

View File

@@ -114,6 +114,14 @@ class Agent(BaseAgent):
default=None, default=None,
description="Embedder configuration for the agent.", description="Embedder configuration for the agent.",
) )
agent_knowledge_context: Optional[str] = Field(
default=None,
description="Knowledge context for the agent.",
)
crew_knowledge_context: Optional[str] = Field(
default=None,
description="Knowledge context for the crew.",
)
@model_validator(mode="after") @model_validator(mode="after")
def post_init_setup(self): def post_init_setup(self):
@@ -229,22 +237,30 @@ class Agent(BaseAgent):
memory = contextual_memory.build_context_for_task(task, context) memory = contextual_memory.build_context_for_task(task, context)
if memory.strip() != "": if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory) task_prompt += self.i18n.slice("memory").format(memory=memory)
knowledge_config = (
self.knowledge_config.model_dump() if self.knowledge_config else {}
)
if self.knowledge: if self.knowledge:
agent_knowledge_snippets = self.knowledge.query([task.prompt()]) agent_knowledge_snippets = self.knowledge.query(
[task.prompt()], **knowledge_config
)
if agent_knowledge_snippets: if agent_knowledge_snippets:
agent_knowledge_context = extract_knowledge_context( self.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets agent_knowledge_snippets
) )
if agent_knowledge_context: if self.agent_knowledge_context:
task_prompt += agent_knowledge_context task_prompt += self.agent_knowledge_context
if self.crew: if self.crew:
knowledge_snippets = self.crew.query_knowledge([task.prompt()]) knowledge_snippets = self.crew.query_knowledge(
[task.prompt()], **knowledge_config
)
if knowledge_snippets: if knowledge_snippets:
crew_knowledge_context = extract_knowledge_context(knowledge_snippets) self.crew_knowledge_context = extract_knowledge_context(
if crew_knowledge_context: knowledge_snippets
task_prompt += crew_knowledge_context )
if self.crew_knowledge_context:
task_prompt += self.crew_knowledge_context
tools = tools or self.tools or [] tools = tools or self.tools or []
self.create_agent_executor(tools=tools, task=task) self.create_agent_executor(tools=tools, task=task)

View File

@@ -19,6 +19,7 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces
from crewai.agents.cache.cache_handler import CacheHandler from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.tools_handler import ToolsHandler from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.security.security_config import SecurityConfig from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool from crewai.tools.base_tool import BaseTool, Tool
@@ -155,6 +156,10 @@ class BaseAgent(ABC, BaseModel):
adapted_agent: bool = Field( adapted_agent: bool = Field(
default=False, description="Whether the agent is adapted" default=False, description="Whether the agent is adapted"
) )
knowledge_config: Optional[KnowledgeConfig] = Field(
default=None,
description="Knowledge configuration for the agent such as limits and threshold",
)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod

View File

@@ -304,9 +304,7 @@ class Crew(BaseModel):
"""Initialize private memory attributes.""" """Initialize private memory attributes."""
self._external_memory = ( self._external_memory = (
# External memory doesnt support a default value since it was designed to be managed entirely externally # External memory doesnt support a default value since it was designed to be managed entirely externally
self.external_memory.set_crew(self) self.external_memory.set_crew(self) if self.external_memory else None
if self.external_memory
else None
) )
self._long_term_memory = self.long_term_memory self._long_term_memory = self.long_term_memory
@@ -1136,9 +1134,13 @@ class Crew(BaseModel):
result = self._execute_tasks(self.tasks, start_index, True) result = self._execute_tasks(self.tasks, start_index, True)
return result return result
def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]: def query_knowledge(
self, query: List[str], limit: int = 3, score_threshold: float = 0.35
) -> Union[List[Dict[str, Any]], None]:
if self.knowledge: if self.knowledge:
return self.knowledge.query(query) return self.knowledge.query(
query, limit=limit, score_threshold=score_threshold
)
return None return None
def fetch_inputs(self) -> Set[str]: def fetch_inputs(self) -> Set[str]:
@@ -1220,9 +1222,13 @@ class Crew(BaseModel):
copied_data = self.model_dump(exclude=exclude) copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None} copied_data = {k: v for k, v in copied_data.items() if v is not None}
if self.short_term_memory: if self.short_term_memory:
copied_data["short_term_memory"] = self.short_term_memory.model_copy(deep=True) copied_data["short_term_memory"] = self.short_term_memory.model_copy(
deep=True
)
if self.long_term_memory: if self.long_term_memory:
copied_data["long_term_memory"] = self.long_term_memory.model_copy(deep=True) copied_data["long_term_memory"] = self.long_term_memory.model_copy(
deep=True
)
if self.entity_memory: if self.entity_memory:
copied_data["entity_memory"] = self.entity_memory.model_copy(deep=True) copied_data["entity_memory"] = self.entity_memory.model_copy(deep=True)
if self.external_memory: if self.external_memory:
@@ -1230,7 +1236,6 @@ class Crew(BaseModel):
if self.user_memory: if self.user_memory:
copied_data["user_memory"] = self.user_memory.model_copy(deep=True) copied_data["user_memory"] = self.user_memory.model_copy(deep=True)
copied_data.pop("agents", None) copied_data.pop("agents", None)
copied_data.pop("tasks", None) copied_data.pop("tasks", None)
@@ -1403,7 +1408,10 @@ class Crew(BaseModel):
"short": (getattr(self, "_short_term_memory", None), "short term"), "short": (getattr(self, "_short_term_memory", None), "short term"),
"entity": (getattr(self, "_entity_memory", None), "entity"), "entity": (getattr(self, "_entity_memory", None), "entity"),
"knowledge": (getattr(self, "knowledge", None), "knowledge"), "knowledge": (getattr(self, "knowledge", None), "knowledge"),
"kickoff_outputs": (getattr(self, "_task_output_handler", None), "task output"), "kickoff_outputs": (
getattr(self, "_task_output_handler", None),
"task output",
),
"external": (getattr(self, "_external_memory", None), "external"), "external": (getattr(self, "_external_memory", None), "external"),
} }

View File

@@ -43,7 +43,9 @@ class Knowledge(BaseModel):
self.storage.initialize_knowledge_storage() self.storage.initialize_knowledge_storage()
self._add_sources() self._add_sources()
def query(self, query: List[str], limit: int = 3) -> List[Dict[str, Any]]: def query(
self, query: List[str], limit: int = 3, score_threshold: float = 0.35
) -> List[Dict[str, Any]]:
""" """
Query across all knowledge sources to find the most relevant information. Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks. Returns the top_k most relevant chunks.
@@ -57,6 +59,7 @@ class Knowledge(BaseModel):
results = self.storage.search( results = self.storage.search(
query, query,
limit, limit,
score_threshold=score_threshold,
) )
return results return results

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class KnowledgeConfig(BaseModel):
limit: int = 3
score_threshold: float = 0.35

View File

@@ -4,7 +4,7 @@ import io
import logging import logging
import os import os
import shutil import shutil
from typing import Any, Dict, List, Optional, Union, cast from typing import Any, Dict, List, Optional, Union
import chromadb import chromadb
import chromadb.errors import chromadb.errors

View File

@@ -10,6 +10,8 @@ from crewai import Agent, Crew, Task
from crewai.agents.cache import CacheHandler from crewai.agents.cache import CacheHandler
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
from crewai.agents.parser import CrewAgentParser, OutputParserException from crewai.agents.parser import CrewAgentParser, OutputParserException
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.llm import LLM from crewai.llm import LLM
@@ -259,7 +261,9 @@ def test_cache_hitting():
def handle_tool_end(source, event): def handle_tool_end(source, event):
received_events.append(event) received_events.append(event)
with (patch.object(CacheHandler, "read") as read,): with (
patch.object(CacheHandler, "read") as read,
):
read.return_value = "0" read.return_value = "0"
task = Task( task = Task(
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.", description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
@@ -1611,6 +1615,62 @@ def test_agent_with_knowledge_sources():
assert "red" in result.raw.lower() assert "red" in result.raw.lower()
def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig(limit=10, score_threshold=0.5)
with patch.object(Knowledge, "query") as mock_knowledge_query:
agent = Agent(
role="Information Agent",
goal="Provide information based on knowledge sources",
backstory="You have access to specific knowledge sources.",
llm=LLM(model="gpt-4o-mini"),
knowledge_sources=[string_source],
knowledge_config=knowledge_config,
)
task = Task(
description="What is Brandon's favorite color?",
expected_output="Brandon's favorite color.",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
crew.kickoff()
assert agent.knowledge is not None
mock_knowledge_query.assert_called_once_with(
[task.prompt()],
**knowledge_config.model_dump(),
)
def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_default():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig()
with patch.object(Knowledge, "query") as mock_knowledge_query:
agent = Agent(
role="Information Agent",
goal="Provide information based on knowledge sources",
backstory="You have access to specific knowledge sources.",
llm=LLM(model="gpt-4o-mini"),
knowledge_sources=[string_source],
knowledge_config=knowledge_config,
)
task = Task(
description="What is Brandon's favorite color?",
expected_output="Brandon's favorite color.",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
crew.kickoff()
assert agent.knowledge is not None
mock_knowledge_query.assert_called_once_with(
[task.prompt()],
**knowledge_config.model_dump(),
)
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_knowledge_sources_extensive_role(): def test_agent_with_knowledge_sources_extensive_role():
content = "Brandon's favorite color is red and he likes Mexican food." content = "Brandon's favorite color is red and he likes Mexican food."