mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fix: Agent-level knowledge sources with non-OpenAI embedders
- Remove OpenAI default from KnowledgeStorage - Add proper embedder config inheritance from crew to agent - Improve error messaging for missing embedder config - Add tests for agent-level knowledge sources Fixes #2164 Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -115,7 +115,7 @@ class Agent(BaseAgent):
|
|||||||
default="safe",
|
default="safe",
|
||||||
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
|
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
|
||||||
)
|
)
|
||||||
embedder: Optional[Dict[str, Any]] = Field(
|
embedder_config: Optional[Dict[str, Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Embedder configuration for the agent.",
|
description="Embedder configuration for the agent.",
|
||||||
)
|
)
|
||||||
@@ -150,9 +150,14 @@ class Agent(BaseAgent):
|
|||||||
if isinstance(self.knowledge_sources, list) and all(
|
if isinstance(self.knowledge_sources, list) and all(
|
||||||
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
|
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
|
||||||
):
|
):
|
||||||
|
# Use agent's embedder config if provided, otherwise use crew's
|
||||||
|
embedder_config = self.embedder_config
|
||||||
|
if not embedder_config and self.crew:
|
||||||
|
embedder_config = self.crew.embedder_config
|
||||||
|
|
||||||
self.knowledge = Knowledge(
|
self.knowledge = Knowledge(
|
||||||
sources=self.knowledge_sources,
|
sources=self.knowledge_sources,
|
||||||
embedder=self.embedder,
|
embedder_config=embedder_config,
|
||||||
collection_name=knowledge_agent_name,
|
collection_name=knowledge_agent_name,
|
||||||
storage=self.knowledge_storage or None,
|
storage=self.knowledge_storage or None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ class Crew(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="An instance of the UserMemory to be used by the Crew to store/fetch memories of a specific user.",
|
description="An instance of the UserMemory to be used by the Crew to store/fetch memories of a specific user.",
|
||||||
)
|
)
|
||||||
embedder: Optional[dict] = Field(
|
embedder_config: Optional[dict] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Configuration for the embedder to be used for the crew.",
|
description="Configuration for the embedder to be used for the crew.",
|
||||||
)
|
)
|
||||||
@@ -308,7 +308,7 @@ class Crew(BaseModel):
|
|||||||
):
|
):
|
||||||
self.knowledge = Knowledge(
|
self.knowledge = Knowledge(
|
||||||
sources=self.knowledge_sources,
|
sources=self.knowledge_sources,
|
||||||
embedder=self.embedder,
|
embedder_config=self.embedder_config,
|
||||||
collection_name="crew",
|
collection_name="crew",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -15,29 +15,30 @@ class Knowledge(BaseModel):
|
|||||||
Args:
|
Args:
|
||||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||||
embedder: Optional[Dict[str, Any]] = None
|
embedder_config: Optional[Dict[str, Any]] = None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||||
embedder: Optional[Dict[str, Any]] = None
|
embedder_config: Optional[Dict[str, Any]] = None
|
||||||
collection_name: Optional[str] = None
|
collection_name: Optional[str] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
sources: List[BaseKnowledgeSource],
|
sources: List[BaseKnowledgeSource],
|
||||||
embedder: Optional[Dict[str, Any]] = None,
|
embedder_config: Optional[Dict[str, Any]] = None,
|
||||||
storage: Optional[KnowledgeStorage] = None,
|
storage: Optional[KnowledgeStorage] = None,
|
||||||
**data,
|
**data,
|
||||||
):
|
):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
self.embedder_config = embedder_config
|
||||||
if storage:
|
if storage:
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
else:
|
else:
|
||||||
self.storage = KnowledgeStorage(
|
self.storage = KnowledgeStorage(
|
||||||
embedder=embedder, collection_name=collection_name
|
embedder_config=embedder_config, collection_name=collection_name
|
||||||
)
|
)
|
||||||
self.sources = sources
|
self.sources = sources
|
||||||
self.storage.initialize_knowledge_storage()
|
self.storage.initialize_knowledge_storage()
|
||||||
|
|||||||
@@ -48,11 +48,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedder: Optional[Dict[str, Any]] = None,
|
embedder_config: Optional[Dict[str, Any]] = None,
|
||||||
collection_name: Optional[str] = None,
|
collection_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self._set_embedder_config(embedder)
|
self._set_embedder_config(embedder_config)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
@@ -179,15 +179,14 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _create_default_embedding_function(self):
|
def _create_default_embedding_function(self):
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
raise ValueError(
|
||||||
OpenAIEmbeddingFunction,
|
"No embedder configuration provided. Please provide an embedder configuration "
|
||||||
|
"either at the crew level or agent level. You can configure embeddings using "
|
||||||
|
"the 'embedder_config' parameter with providers like 'openai', 'watson', etc. "
|
||||||
|
"Example: embedder_config={'provider': 'openai', 'config': {'api_key': 'your-key'}}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return OpenAIEmbeddingFunction(
|
def _set_embedder_config(self, embedder_config: Optional[Dict[str, Any]] = None) -> None:
|
||||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
|
|
||||||
"""Set the embedding configuration for the knowledge storage.
|
"""Set the embedding configuration for the knowledge storage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -195,7 +194,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
If None or empty, defaults to the default embedding function.
|
If None or empty, defaults to the default embedding function.
|
||||||
"""
|
"""
|
||||||
self.embedder = (
|
self.embedder = (
|
||||||
EmbeddingConfigurator().configure_embedder(embedder)
|
EmbeddingConfigurator().configure_embedder(embedder_config)
|
||||||
if embedder
|
if embedder_config
|
||||||
else self._create_default_embedding_function()
|
else self._create_default_embedding_function()
|
||||||
)
|
)
|
||||||
|
|||||||
94
tests/test_agent_knowledge.py
Normal file
94
tests/test_agent_knowledge.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from chromadb.api.types import EmbeddingFunction
|
||||||
|
from crewai import Agent, Crew, Task
|
||||||
|
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||||
|
from crewai.process import Process
|
||||||
|
|
||||||
|
class MockEmbeddingFunction(EmbeddingFunction):
|
||||||
|
def __call__(self, texts):
|
||||||
|
return [[0.0] * 1536 for _ in texts]
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_vector_db():
|
||||||
|
"""Mock vector database operations."""
|
||||||
|
with patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage") as mock, \
|
||||||
|
patch("chromadb.PersistentClient") as mock_chroma:
|
||||||
|
# Mock ChromaDB client and collection
|
||||||
|
mock_collection = MagicMock()
|
||||||
|
mock_collection.query.return_value = {
|
||||||
|
"ids": [["1"]],
|
||||||
|
"distances": [[0.1]],
|
||||||
|
"metadatas": [[{"source": "test"}]],
|
||||||
|
"documents": [["Test content"]]
|
||||||
|
}
|
||||||
|
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||||
|
|
||||||
|
# Mock the query method to return a predefined response
|
||||||
|
instance = mock.return_value
|
||||||
|
instance.query.return_value = [
|
||||||
|
{
|
||||||
|
"context": "Test content",
|
||||||
|
"score": 0.9,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
instance.reset.return_value = None
|
||||||
|
yield instance
|
||||||
|
|
||||||
|
def test_agent_knowledge_with_custom_embedder(mock_vector_db):
|
||||||
|
agent = Agent(
|
||||||
|
role="test role",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory",
|
||||||
|
knowledge_sources=[StringKnowledgeSource(content="test content")],
|
||||||
|
embedder_config={
|
||||||
|
"provider": "custom",
|
||||||
|
"config": {
|
||||||
|
"embedder": MockEmbeddingFunction()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert agent.knowledge is not None
|
||||||
|
assert agent.knowledge.storage.embedder is not None
|
||||||
|
|
||||||
|
def test_agent_inherits_crew_embedder(mock_vector_db):
|
||||||
|
test_agent = Agent(
|
||||||
|
role="test role",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory"
|
||||||
|
)
|
||||||
|
test_task = Task(
|
||||||
|
description="test task",
|
||||||
|
expected_output="test output",
|
||||||
|
agent=test_agent
|
||||||
|
)
|
||||||
|
crew = Crew(
|
||||||
|
agents=[test_agent],
|
||||||
|
tasks=[test_task],
|
||||||
|
process=Process.sequential,
|
||||||
|
embedder_config={
|
||||||
|
"provider": "custom",
|
||||||
|
"config": {
|
||||||
|
"embedder": MockEmbeddingFunction()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
agent = Agent(
|
||||||
|
role="test role",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory",
|
||||||
|
knowledge_sources=[StringKnowledgeSource(content="test content")],
|
||||||
|
crew=crew
|
||||||
|
)
|
||||||
|
assert agent.knowledge is not None
|
||||||
|
assert agent.knowledge.storage.embedder is not None
|
||||||
|
|
||||||
|
def test_agent_knowledge_without_embedder_raises_error(mock_vector_db):
|
||||||
|
with pytest.raises(ValueError, match="No embedder configuration provided"):
|
||||||
|
agent = Agent(
|
||||||
|
role="test role",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory",
|
||||||
|
knowledge_sources=[StringKnowledgeSource(content="test content")]
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user