mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 17:18:29 +00:00
fix: Agent-level knowledge sources with non-OpenAI embedders
- Remove OpenAI default from KnowledgeStorage - Add proper embedder config inheritance from crew to agent - Improve error messaging for missing embedder config - Add tests for agent-level knowledge sources Fixes #2164 Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
94
tests/test_agent_knowledge.py
Normal file
94
tests/test_agent_knowledge.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from chromadb.api.types import EmbeddingFunction
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.process import Process
|
||||
|
||||
class MockEmbeddingFunction(EmbeddingFunction):
|
||||
def __call__(self, texts):
|
||||
return [[0.0] * 1536 for _ in texts]
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_vector_db():
|
||||
"""Mock vector database operations."""
|
||||
with patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage") as mock, \
|
||||
patch("chromadb.PersistentClient") as mock_chroma:
|
||||
# Mock ChromaDB client and collection
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.query.return_value = {
|
||||
"ids": [["1"]],
|
||||
"distances": [[0.1]],
|
||||
"metadatas": [[{"source": "test"}]],
|
||||
"documents": [["Test content"]]
|
||||
}
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
# Mock the query method to return a predefined response
|
||||
instance = mock.return_value
|
||||
instance.query.return_value = [
|
||||
{
|
||||
"context": "Test content",
|
||||
"score": 0.9,
|
||||
}
|
||||
]
|
||||
instance.reset.return_value = None
|
||||
yield instance
|
||||
|
||||
def test_agent_knowledge_with_custom_embedder(mock_vector_db):
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
knowledge_sources=[StringKnowledgeSource(content="test content")],
|
||||
embedder_config={
|
||||
"provider": "custom",
|
||||
"config": {
|
||||
"embedder": MockEmbeddingFunction()
|
||||
}
|
||||
}
|
||||
)
|
||||
assert agent.knowledge is not None
|
||||
assert agent.knowledge.storage.embedder is not None
|
||||
|
||||
def test_agent_inherits_crew_embedder(mock_vector_db):
|
||||
test_agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory"
|
||||
)
|
||||
test_task = Task(
|
||||
description="test task",
|
||||
expected_output="test output",
|
||||
agent=test_agent
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[test_agent],
|
||||
tasks=[test_task],
|
||||
process=Process.sequential,
|
||||
embedder_config={
|
||||
"provider": "custom",
|
||||
"config": {
|
||||
"embedder": MockEmbeddingFunction()
|
||||
}
|
||||
}
|
||||
)
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
knowledge_sources=[StringKnowledgeSource(content="test content")],
|
||||
crew=crew
|
||||
)
|
||||
assert agent.knowledge is not None
|
||||
assert agent.knowledge.storage.embedder is not None
|
||||
|
||||
def test_agent_knowledge_without_embedder_raises_error(mock_vector_db):
|
||||
with pytest.raises(ValueError, match="No embedder configuration provided"):
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
knowledge_sources=[StringKnowledgeSource(content="test content")]
|
||||
)
|
||||
Reference in New Issue
Block a user