Compare commits

...

13 Commits

Author SHA1 Message Date
Devin AI
fd4081be72 fix: Update remaining embedder reference to embedder_config
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:49:01 +00:00
Devin AI
1508c9810b test: Add embedder_config to contextual memory test
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:44:18 +00:00
Devin AI
2f3e5e0803 test: Add embedder_config to knowledge source tests
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:42:50 +00:00
Devin AI
615a6795b3 fix: Sort imports in agent.py using ruff
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:38:36 +00:00
Devin AI
633f6973b2 test: Update error message in test_agent_invalid_embedder_config
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:36:12 +00:00
Devin AI
fb4bdad367 fix: Sort imports in test_agent_knowledge.py using ruff
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:34:36 +00:00
Devin AI
c50a88fd40 test: Add comprehensive validation tests for embedder_config
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:32:28 +00:00
Devin AI
626b765b86 fix: Sort imports in test_agent_knowledge.py according to isort standards
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:32:09 +00:00
Devin AI
9cef78a30f feat: Update embedder_config validation to use Pydantic v2 style
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:31:46 +00:00
Devin AI
566ea3ced8 feat: Add validation and improve documentation for embedder_config
- Add validation for embedder_config in Agent class
- Add test cases for invalid embedder configurations
- Improve docstrings with examples and error cases

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:28:34 +00:00
Devin AI
59977a5f7c fix: Sort imports in test_agent_knowledge.py according to standard order
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:27:15 +00:00
Devin AI
bd21eaaf0e fix: Sort imports in test_agent_knowledge.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:25:11 +00:00
Devin AI
f02db1a4f5 fix: Agent-level knowledge sources with non-OpenAI embedders
- Remove OpenAI default from KnowledgeStorage
- Add proper embedder config inheritance from crew to agent
- Improve error messaging for missing embedder config
- Add tests for agent-level knowledge sources

Fixes #2164

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-19 15:23:53 +00:00
8 changed files with 237 additions and 27 deletions

View File

@@ -3,7 +3,7 @@ import shutil
 import subprocess
 from typing import Any, Dict, List, Literal, Optional, Sequence, Union

-from pydantic import Field, InstanceOf, PrivateAttr, model_validator
+from pydantic import Field, InstanceOf, PrivateAttr, field_validator, model_validator

 from crewai.agents import CacheHandler
 from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -16,7 +16,7 @@ from crewai.memory.contextual.contextual_memory import ContextualMemory
 from crewai.task import Task
 from crewai.tools import BaseTool
 from crewai.tools.agent_tools.agent_tools import AgentTools
-from crewai.utilities import Converter, Prompts
+from crewai.utilities import Converter, EmbeddingConfigurator, Prompts
 from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
 from crewai.utilities.converter import generate_model_description
 from crewai.utilities.llm_utils import create_llm
@@ -115,11 +115,48 @@ class Agent(BaseAgent):
         default="safe",
         description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
     )
-    embedder: Optional[Dict[str, Any]] = Field(
+    embedder_config: Optional[Dict[str, Any]] = Field(
         default=None,
-        description="Embedder configuration for the agent.",
+        description="Embedder configuration for the agent. Must include 'provider' and relevant configuration parameters.",
     )
+
+    @field_validator("embedder_config")
+    @classmethod
+    def validate_embedder_config(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+        """Validate embedder configuration.
+
+        Args:
+            v: The embedder configuration to validate.
+                Must include 'provider' and 'config' keys.
+
+        Example:
+            {
+                'provider': 'openai',
+                'config': {
+                    'api_key': 'your-key',
+                    'model': 'text-embedding-3-small'
+                }
+            }
+
+        Returns:
+            The validated embedder configuration.
+
+        Raises:
+            ValueError: If the embedder configuration is invalid.
+        """
+        if v is not None:
+            if not isinstance(v, dict):
+                raise ValueError("embedder_config must be a dictionary")
+            if "provider" not in v:
+                raise ValueError("embedder_config must contain 'provider' key")
+            if "config" not in v:
+                raise ValueError("embedder_config must contain 'config' key")
+            if v["provider"] not in EmbeddingConfigurator().embedding_functions:
+                raise ValueError(
+                    f"Unsupported embedding provider: {v['provider']}, "
+                    f"supported providers: {list(EmbeddingConfigurator().embedding_functions.keys())}"
+                )
+        return v
+
     @model_validator(mode="after")
     def post_init_setup(self):
         self._set_knowledge()
@@ -150,9 +187,14 @@ class Agent(BaseAgent):
         if isinstance(self.knowledge_sources, list) and all(
             isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
         ):
+            # Use agent's embedder config if provided, otherwise use crew's
+            embedder_config = self.embedder_config
+            if not embedder_config and self.crew:
+                embedder_config = self.crew.embedder_config
             self.knowledge = Knowledge(
                 sources=self.knowledge_sources,
-                embedder=self.embedder,
+                embedder_config=embedder_config,
                 collection_name=knowledge_agent_name,
                 storage=self.knowledge_storage or None,
             )
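
Taken together, the hunks above rename the agent-level field from embedder to embedder_config, validate its shape, and fall back to the crew's configuration when the agent has none. A minimal usage sketch follows; the role strings, sample content, and placeholder API key are illustrative only, and the import paths are the ones exercised by the test diffs further down:

from crewai import Agent
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

# Agent-level embedder_config: validate_embedder_config requires both the
# 'provider' and 'config' keys and a provider known to EmbeddingConfigurator.
agent = Agent(
    role="Research Assistant",                      # illustrative values
    goal="Answer questions from the attached knowledge source",
    backstory="You have access to specific knowledge sources.",
    knowledge_sources=[StringKnowledgeSource(content="CrewAI agents can carry their own embedder.")],
    embedder_config={
        "provider": "openai",
        "config": {
            "api_key": "your-key",                  # placeholder, not a real key
            "model": "text-embedding-3-small",
        },
    },
)

# A config missing 'provider' or 'config', or naming an unknown provider,
# now raises ValueError when the agent is constructed.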

View File

@@ -138,7 +138,7 @@ class Crew(BaseModel):
         default=None,
         description="An instance of the UserMemory to be used by the Crew to store/fetch memories of a specific user.",
     )
-    embedder: Optional[dict] = Field(
+    embedder_config: Optional[dict] = Field(
         default=None,
         description="Configuration for the embedder to be used for the crew.",
     )
@@ -268,13 +268,13 @@
             if self.short_term_memory
             else ShortTermMemory(
                 crew=self,
-                embedder_config=self.embedder,
+                embedder_config=self.embedder_config,
             )
         )
         self._entity_memory = (
             self.entity_memory
             if self.entity_memory
-            else EntityMemory(crew=self, embedder_config=self.embedder)
+            else EntityMemory(crew=self, embedder_config=self.embedder_config)
         )
         if (
             self.memory_config and "user_memory" in self.memory_config
@@ -308,7 +308,7 @@
         ):
             self.knowledge = Knowledge(
                 sources=self.knowledge_sources,
-                embedder=self.embedder,
+                embedder_config=self.embedder_config,
                 collection_name="crew",
             )
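
The crew-level field is renamed the same way and is what the Agent fallback above reads. A small sketch of that inheritance path, mirroring the test_agent_inherits_crew_embedder test added later in this compare; the role strings, provider choice, and key are illustrative:

from crewai import Agent, Crew, Task
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

# Crew-level embedder_config, shared by agents that do not set their own.
planner = Agent(
    role="Planner",
    goal="Coordinate the research",
    backstory="You coordinate the crew.",
)
crew = Crew(
    agents=[planner],
    tasks=[Task(description="Plan the research", expected_output="A plan", agent=planner)],
    embedder_config={
        "provider": "openai",
        "config": {"api_key": "your-key"},          # placeholder, not a real key
    },
)

# An agent created with crew=crew, knowledge sources, and no embedder_config
# of its own inherits the crew's configuration through the fallback added
# in Agent._set_knowledge (agent.py hunk above).
researcher = Agent(
    role="Researcher",
    goal="Answer questions about the shared notes",
    backstory="You rely on the crew's shared embedder configuration.",
    knowledge_sources=[StringKnowledgeSource(content="Shared crew notes.")],
    crew=crew,
)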

View File

@@ -15,29 +15,30 @@ class Knowledge(BaseModel):
     Args:
         sources: List[BaseKnowledgeSource] = Field(default_factory=list)
         storage: Optional[KnowledgeStorage] = Field(default=None)
-        embedder: Optional[Dict[str, Any]] = None
+        embedder_config: Optional[Dict[str, Any]] = None
     """

     sources: List[BaseKnowledgeSource] = Field(default_factory=list)
     model_config = ConfigDict(arbitrary_types_allowed=True)
     storage: Optional[KnowledgeStorage] = Field(default=None)
-    embedder: Optional[Dict[str, Any]] = None
+    embedder_config: Optional[Dict[str, Any]] = None
     collection_name: Optional[str] = None

     def __init__(
         self,
         collection_name: str,
         sources: List[BaseKnowledgeSource],
-        embedder: Optional[Dict[str, Any]] = None,
+        embedder_config: Optional[Dict[str, Any]] = None,
         storage: Optional[KnowledgeStorage] = None,
         **data,
     ):
         super().__init__(**data)
+        self.embedder_config = embedder_config
         if storage:
             self.storage = storage
         else:
             self.storage = KnowledgeStorage(
-                embedder=embedder, collection_name=collection_name
+                embedder_config=embedder_config, collection_name=collection_name
             )
         self.sources = sources
         self.storage.initialize_knowledge_storage()
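
Knowledge itself just stores the renamed field and forwards it to the KnowledgeStorage it creates. A short sketch, assuming Knowledge is importable from crewai.knowledge.knowledge (the module path is not shown in this diff) and using a placeholder key:

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

# Knowledge keeps the config on self.embedder_config and passes it to the
# KnowledgeStorage it builds when no storage instance is supplied.
knowledge = Knowledge(
    collection_name="demo",
    sources=[StringKnowledgeSource(content="Some reference text.")],
    embedder_config={
        "provider": "openai",
        "config": {"api_key": "your-key"},  # placeholder, not a real key
    },
)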

View File

@@ -48,11 +48,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
     def __init__(
         self,
-        embedder: Optional[Dict[str, Any]] = None,
+        embedder_config: Optional[Dict[str, Any]] = None,
         collection_name: Optional[str] = None,
     ):
         self.collection_name = collection_name
-        self._set_embedder_config(embedder)
+        self._set_embedder_config(embedder_config)

     def search(
         self,
@@ -179,23 +179,32 @@ class KnowledgeStorage(BaseKnowledgeStorage):
             raise

     def _create_default_embedding_function(self):
-        from chromadb.utils.embedding_functions.openai_embedding_function import (
-            OpenAIEmbeddingFunction,
-        )
-
-        return OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
-        )
-
-    def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
+        raise ValueError(
+            "No embedder configuration provided. Please provide an embedder configuration "
+            "either at the crew level or agent level. You can configure embeddings using "
+            "the 'embedder_config' parameter with providers like 'openai', 'watson', etc. "
+            "Example: embedder_config={'provider': 'openai', 'config': {'api_key': 'your-key'}}"
+        )
+
+    def _set_embedder_config(self, embedder_config: Optional[Dict[str, Any]] = None) -> None:
         """Set the embedding configuration for the knowledge storage.

         Args:
-            embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
-                If None or empty, defaults to the default embedding function.
+            embedder_config: Must include 'provider' and relevant configuration parameters.
+                For example:
+                {
+                    'provider': 'openai',
+                    'config': {
+                        'api_key': 'your-key',
+                        'model': 'text-embedding-3-small'
+                    }
+                }
+
+        Raises:
+            ValueError: If no configuration is provided or if the configuration is invalid.
         """
         self.embedder = (
-            EmbeddingConfigurator().configure_embedder(embedder)
-            if embedder
+            EmbeddingConfigurator().configure_embedder(embedder_config)
+            if embedder_config
             else self._create_default_embedding_function()
         )
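
With the OpenAI default removed, building the storage without a configuration now fails at construction time rather than silently using text-embedding-3-small. A sketch of both paths; the collection name and key are placeholders:

from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage

# _set_embedder_config runs in __init__, so a missing embedder_config
# surfaces the new ValueError immediately.
try:
    KnowledgeStorage(collection_name="demo")
except ValueError as exc:
    print(exc)  # "No embedder configuration provided. ..."

# Supplying a configuration routes through EmbeddingConfigurator as before.
storage = KnowledgeStorage(
    embedder_config={
        "provider": "openai",
        "config": {"api_key": "your-key", "model": "text-embedding-3-small"},  # placeholder key
    },
    collection_name="demo",
)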

View File

@@ -1596,6 +1596,12 @@ def test_agent_with_knowledge_sources():
         backstory="You have access to specific knowledge sources.",
         llm=LLM(model="gpt-4o-mini"),
         knowledge_sources=[string_source],
+        embedder_config={
+            "provider": "openai",
+            "config": {
+                "api_key": "fake-api-key"
+            }
+        }
     )

     # Create a task that requires the agent to use the knowledge
@@ -1631,6 +1637,12 @@ def test_agent_with_knowledge_sources_works_with_copy():
         backstory="You have access to specific knowledge sources.",
         llm=LLM(model="gpt-4o-mini"),
         knowledge_sources=[string_source],
+        embedder_config={
+            "provider": "openai",
+            "config": {
+                "api_key": "fake-api-key"
+            }
+        }
     )

     with patch(

View File

@@ -2322,6 +2322,12 @@ def test_using_contextual_memory():
         agents=[math_researcher],
         tasks=[task1],
         memory=True,
+        embedder_config={
+            "provider": "openai",
+            "config": {
+                "api_key": "fake-api-key"
+            }
+        }
     )

     with patch.object(ContextualMemory, "build_context_for_task") as contextual_mem:

View File

@@ -0,0 +1,134 @@
from unittest.mock import MagicMock, patch

import pytest
from chromadb.api.types import EmbeddingFunction

from crewai import Agent, Crew, Task
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.process import Process


class MockEmbeddingFunction(EmbeddingFunction):
    def __call__(self, texts):
        return [[0.0] * 1536 for _ in texts]


@pytest.fixture(autouse=True)
def mock_vector_db():
    """Mock vector database operations."""
    with patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage") as mock, \
         patch("chromadb.PersistentClient") as mock_chroma:
        # Mock ChromaDB client and collection
        mock_collection = MagicMock()
        mock_collection.query.return_value = {
            "ids": [["1"]],
            "distances": [[0.1]],
            "metadatas": [[{"source": "test"}]],
            "documents": [["Test content"]]
        }
        mock_chroma.return_value.get_or_create_collection.return_value = mock_collection

        # Mock the query method to return a predefined response
        instance = mock.return_value
        instance.query.return_value = [
            {
                "context": "Test content",
                "score": 0.9,
            }
        ]
        instance.reset.return_value = None
        yield instance


def test_agent_invalid_embedder_config():
    """Test that an invalid embedder configuration raises a ValueError."""
    with pytest.raises(ValueError, match="Input should be a valid dictionary"):
        Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            knowledge_sources=[StringKnowledgeSource(content="test content")],
            embedder_config="invalid"
        )

    with pytest.raises(ValueError, match="embedder_config must contain 'provider' key"):
        Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            knowledge_sources=[StringKnowledgeSource(content="test content")],
            embedder_config={"invalid": "config"}
        )

    with pytest.raises(ValueError, match="embedder_config must contain 'config' key"):
        Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            knowledge_sources=[StringKnowledgeSource(content="test content")],
            embedder_config={"provider": "custom"}
        )

    with pytest.raises(ValueError, match="Unsupported embedding provider"):
        Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            knowledge_sources=[StringKnowledgeSource(content="test content")],
            embedder_config={"provider": "invalid", "config": {}}
        )


def test_agent_knowledge_with_custom_embedder(mock_vector_db):
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        knowledge_sources=[StringKnowledgeSource(content="test content")],
        embedder_config={
            "provider": "custom",
            "config": {
                "embedder": MockEmbeddingFunction()
            }
        }
    )
    assert agent.knowledge is not None
    assert agent.knowledge.storage.embedder is not None


def test_agent_inherits_crew_embedder(mock_vector_db):
    test_agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory"
    )
    test_task = Task(
        description="test task",
        expected_output="test output",
        agent=test_agent
    )
    crew = Crew(
        agents=[test_agent],
        tasks=[test_task],
        process=Process.sequential,
        embedder_config={
            "provider": "custom",
            "config": {
                "embedder": MockEmbeddingFunction()
            }
        }
    )

    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        knowledge_sources=[StringKnowledgeSource(content="test content")],
        crew=crew
    )
    assert agent.knowledge is not None
    assert agent.knowledge.storage.embedder is not None


def test_agent_knowledge_without_embedder_raises_error(mock_vector_db):
    with pytest.raises(ValueError, match="No embedder configuration provided"):
        agent = Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            knowledge_sources=[StringKnowledgeSource(content="test content")]
        )

View File

@@ -45,7 +45,13 @@ def test_knowledge_included_in_planning(mock_chroma):
             StringKnowledgeSource(
                 content="AI systems require careful training and validation."
             )
-        ]
+        ],
+        embedder_config={
+            "provider": "openai",
+            "config": {
+                "api_key": "fake-api-key"
+            }
+        }
     )

     # Create a task for the agent