Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-16 12:28:30 +00:00)

Compare commits: 1.2.1 ... devin/1739 (13 commits)
| SHA1 |
|---|
| fd4081be72 |
| 1508c9810b |
| 2f3e5e0803 |
| 615a6795b3 |
| 633f6973b2 |
| fb4bdad367 |
| c50a88fd40 |
| 626b765b86 |
| 9cef78a30f |
| 566ea3ced8 |
| 59977a5f7c |
| bd21eaaf0e |
| f02db1a4f5 |
```diff
@@ -3,7 +3,7 @@ import shutil
 import subprocess
 from typing import Any, Dict, List, Literal, Optional, Sequence, Union
 
-from pydantic import Field, InstanceOf, PrivateAttr, model_validator
+from pydantic import Field, InstanceOf, PrivateAttr, field_validator, model_validator
 
 from crewai.agents import CacheHandler
 from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -16,7 +16,7 @@ from crewai.memory.contextual.contextual_memory import ContextualMemory
 from crewai.task import Task
 from crewai.tools import BaseTool
 from crewai.tools.agent_tools.agent_tools import AgentTools
-from crewai.utilities import Converter, Prompts
+from crewai.utilities import Converter, EmbeddingConfigurator, Prompts
 from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
 from crewai.utilities.converter import generate_model_description
 from crewai.utilities.llm_utils import create_llm
```
```diff
@@ -115,11 +115,48 @@ class Agent(BaseAgent):
         default="safe",
         description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
     )
-    embedder: Optional[Dict[str, Any]] = Field(
+    embedder_config: Optional[Dict[str, Any]] = Field(
         default=None,
-        description="Embedder configuration for the agent.",
+        description="Embedder configuration for the agent. Must include 'provider' and relevant configuration parameters.",
     )
+
+    @field_validator("embedder_config")
+    @classmethod
+    def validate_embedder_config(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+        """Validate embedder configuration.
+
+        Args:
+            v: The embedder configuration to validate.
+                Must include 'provider' and 'config' keys.
+                Example:
+                    {
+                        'provider': 'openai',
+                        'config': {
+                            'api_key': 'your-key',
+                            'model': 'text-embedding-3-small'
+                        }
+                    }
+
+        Returns:
+            The validated embedder configuration.
+
+        Raises:
+            ValueError: If the embedder configuration is invalid.
+        """
+        if v is not None:
+            if not isinstance(v, dict):
+                raise ValueError("embedder_config must be a dictionary")
+            if "provider" not in v:
+                raise ValueError("embedder_config must contain 'provider' key")
+            if "config" not in v:
+                raise ValueError("embedder_config must contain 'config' key")
+            if v["provider"] not in EmbeddingConfigurator().embedding_functions:
+                raise ValueError(
+                    f"Unsupported embedding provider: {v['provider']}, "
+                    f"supported providers: {list(EmbeddingConfigurator().embedding_functions.keys())}"
+                )
+        return v
+
     @model_validator(mode="after")
     def post_init_setup(self):
         self._set_knowledge()
@@ -150,9 +187,14 @@ class Agent(BaseAgent):
         if isinstance(self.knowledge_sources, list) and all(
             isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
         ):
+            # Use agent's embedder config if provided, otherwise use crew's
+            embedder_config = self.embedder_config
+            if not embedder_config and self.crew:
+                embedder_config = self.crew.embedder_config
+
             self.knowledge = Knowledge(
                 sources=self.knowledge_sources,
-                embedder=self.embedder,
+                embedder_config=embedder_config,
                 collection_name=knowledge_agent_name,
                 storage=self.knowledge_storage or None,
             )
```
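Taken together, the Agent-side hunks rename `embedder` to `embedder_config`, validate its shape up front, and fall back to the crew's configuration when the agent has none. A minimal usage sketch based on the field description and the tests in this compare; the key and model values are placeholders, not defaults:

```python
from crewai import Agent
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

# The validator above requires both 'provider' and 'config' keys and
# rejects providers unknown to EmbeddingConfigurator with a ValueError.
agent = Agent(
    role="Researcher",
    goal="Answer questions from the provided notes",
    backstory="You have access to specific knowledge sources.",
    knowledge_sources=[StringKnowledgeSource(content="Some reference text.")],
    embedder_config={
        "provider": "openai",
        "config": {
            "api_key": "your-key",  # placeholder
            "model": "text-embedding-3-small",
        },
    },
)
```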
```diff
@@ -138,7 +138,7 @@ class Crew(BaseModel):
         default=None,
         description="An instance of the UserMemory to be used by the Crew to store/fetch memories of a specific user.",
     )
-    embedder: Optional[dict] = Field(
+    embedder_config: Optional[dict] = Field(
         default=None,
         description="Configuration for the embedder to be used for the crew.",
     )
@@ -268,13 +268,13 @@
             if self.short_term_memory
             else ShortTermMemory(
                 crew=self,
-                embedder_config=self.embedder,
+                embedder_config=self.embedder_config,
             )
         )
         self._entity_memory = (
             self.entity_memory
             if self.entity_memory
-            else EntityMemory(crew=self, embedder_config=self.embedder)
+            else EntityMemory(crew=self, embedder_config=self.embedder_config)
         )
         if (
             self.memory_config and "user_memory" in self.memory_config
@@ -308,7 +308,7 @@
         ):
             self.knowledge = Knowledge(
                 sources=self.knowledge_sources,
-                embedder=self.embedder,
+                embedder_config=self.embedder_config,
                 collection_name="crew",
             )
 
```
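On the Crew side the field is renamed the same way and threaded into short-term memory, entity memory, and crew-level knowledge. A sketch of crew-level configuration that agents without their own `embedder_config` inherit, per the `_set_knowledge` hunk above (values are placeholders):

```python
from crewai import Agent, Crew, Task

researcher = Agent(
    role="Researcher",
    goal="Answer questions",
    backstory="A careful researcher.",
)
task = Task(
    description="Answer the question",
    expected_output="A short answer",
    agent=researcher,
)

# Agents that define knowledge_sources but no embedder_config of their own
# pick up this configuration when building their Knowledge store.
crew = Crew(
    agents=[researcher],
    tasks=[task],
    embedder_config={
        "provider": "openai",
        "config": {"api_key": "your-key"},  # placeholder
    },
)
```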
```diff
@@ -15,29 +15,30 @@ class Knowledge(BaseModel):
     Args:
         sources: List[BaseKnowledgeSource] = Field(default_factory=list)
         storage: Optional[KnowledgeStorage] = Field(default=None)
-        embedder: Optional[Dict[str, Any]] = None
+        embedder_config: Optional[Dict[str, Any]] = None
     """
 
     sources: List[BaseKnowledgeSource] = Field(default_factory=list)
     model_config = ConfigDict(arbitrary_types_allowed=True)
     storage: Optional[KnowledgeStorage] = Field(default=None)
-    embedder: Optional[Dict[str, Any]] = None
+    embedder_config: Optional[Dict[str, Any]] = None
     collection_name: Optional[str] = None
 
     def __init__(
         self,
         collection_name: str,
         sources: List[BaseKnowledgeSource],
-        embedder: Optional[Dict[str, Any]] = None,
+        embedder_config: Optional[Dict[str, Any]] = None,
         storage: Optional[KnowledgeStorage] = None,
         **data,
     ):
         super().__init__(**data)
+        self.embedder_config = embedder_config
         if storage:
             self.storage = storage
         else:
             self.storage = KnowledgeStorage(
-                embedder=embedder, collection_name=collection_name
+                embedder_config=embedder_config, collection_name=collection_name
             )
         self.sources = sources
         self.storage.initialize_knowledge_storage()
```
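The `Knowledge` constructor keeps the same positional arguments but now takes `embedder_config` and stores it on the instance before building its storage. A sketch of direct construction, assuming the class is importable from `crewai.knowledge.knowledge` (placeholder key):

```python
from crewai.knowledge.knowledge import Knowledge  # assumed module path
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

# Without an explicit storage argument, this builds a KnowledgeStorage
# configured from embedder_config, as shown in the hunk above.
knowledge = Knowledge(
    collection_name="docs",
    sources=[StringKnowledgeSource(content="Reference material.")],
    embedder_config={
        "provider": "openai",
        "config": {"api_key": "your-key"},  # placeholder
    },
)
```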
```diff
@@ -48,11 +48,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
 
     def __init__(
         self,
-        embedder: Optional[Dict[str, Any]] = None,
+        embedder_config: Optional[Dict[str, Any]] = None,
         collection_name: Optional[str] = None,
     ):
         self.collection_name = collection_name
-        self._set_embedder_config(embedder)
+        self._set_embedder_config(embedder_config)
 
     def search(
         self,
@@ -179,23 +179,32 @@ class KnowledgeStorage(BaseKnowledgeStorage):
             raise
 
     def _create_default_embedding_function(self):
-        from chromadb.utils.embedding_functions.openai_embedding_function import (
-            OpenAIEmbeddingFunction,
+        raise ValueError(
+            "No embedder configuration provided. Please provide an embedder configuration "
+            "either at the crew level or agent level. You can configure embeddings using "
+            "the 'embedder_config' parameter with providers like 'openai', 'watson', etc. "
+            "Example: embedder_config={'provider': 'openai', 'config': {'api_key': 'your-key'}}"
         )
 
-        return OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
-        )
-
-    def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
+    def _set_embedder_config(self, embedder_config: Optional[Dict[str, Any]] = None) -> None:
+        """Set the embedding configuration for the knowledge storage.
+
+        Args:
+            embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
+                If None or empty, defaults to the default embedding function.
+                embedder_config: Must include 'provider' and relevant configuration parameters.
+                For example:
+                    {
+                        'provider': 'openai',
+                        'config': {
+                            'api_key': 'your-key',
+                            'model': 'text-embedding-3-small'
+                        }
+                    }
+
+        Raises:
+            ValueError: If no configuration is provided or if the configuration is invalid.
+        """
         self.embedder = (
-            EmbeddingConfigurator().configure_embedder(embedder)
-            if embedder
+            EmbeddingConfigurator().configure_embedder(embedder_config)
+            if embedder_config
             else self._create_default_embedding_function()
         )
```
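Because `_create_default_embedding_function` now raises instead of silently falling back to OpenAI embeddings, building knowledge without any embedder configuration fails fast. A sketch of the resulting behavior, mirroring `test_agent_knowledge_without_embedder_raises_error` in the new test file below:

```python
import pytest

from crewai import Agent
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

# Neither the agent nor any crew supplies embedder_config, so
# KnowledgeStorage ends up calling _create_default_embedding_function,
# which now raises instead of defaulting to OpenAI embeddings.
with pytest.raises(ValueError, match="No embedder configuration provided"):
    Agent(
        role="Researcher",
        goal="Answer questions",
        backstory="A careful researcher.",
        knowledge_sources=[StringKnowledgeSource(content="Reference material.")],
    )
```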
```diff
@@ -1596,6 +1596,12 @@ def test_agent_with_knowledge_sources():
         backstory="You have access to specific knowledge sources.",
         llm=LLM(model="gpt-4o-mini"),
         knowledge_sources=[string_source],
+        embedder_config={
+            "provider": "openai",
+            "config": {
+                "api_key": "fake-api-key"
+            }
+        }
     )
 
     # Create a task that requires the agent to use the knowledge
@@ -1631,6 +1637,12 @@ def test_agent_with_knowledge_sources_works_with_copy():
         backstory="You have access to specific knowledge sources.",
         llm=LLM(model="gpt-4o-mini"),
         knowledge_sources=[string_source],
+        embedder_config={
+            "provider": "openai",
+            "config": {
+                "api_key": "fake-api-key"
+            }
+        }
     )
 
     with patch(
@@ -2322,6 +2322,12 @@ def test_using_contextual_memory():
         agents=[math_researcher],
         tasks=[task1],
         memory=True,
+        embedder_config={
+            "provider": "openai",
+            "config": {
+                "api_key": "fake-api-key"
+            }
+        }
     )
 
     with patch.object(ContextualMemory, "build_context_for_task") as contextual_mem:
```
tests/test_agent_knowledge.py (new file, 134 lines):

```python
from unittest.mock import MagicMock, patch

import pytest
from chromadb.api.types import EmbeddingFunction

from crewai import Agent, Crew, Task
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.process import Process


class MockEmbeddingFunction(EmbeddingFunction):
    def __call__(self, texts):
        return [[0.0] * 1536 for _ in texts]


@pytest.fixture(autouse=True)
def mock_vector_db():
    """Mock vector database operations."""
    with patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage") as mock, \
         patch("chromadb.PersistentClient") as mock_chroma:
        # Mock ChromaDB client and collection
        mock_collection = MagicMock()
        mock_collection.query.return_value = {
            "ids": [["1"]],
            "distances": [[0.1]],
            "metadatas": [[{"source": "test"}]],
            "documents": [["Test content"]]
        }
        mock_chroma.return_value.get_or_create_collection.return_value = mock_collection

        # Mock the query method to return a predefined response
        instance = mock.return_value
        instance.query.return_value = [
            {
                "context": "Test content",
                "score": 0.9,
            }
        ]
        instance.reset.return_value = None
        yield instance


def test_agent_invalid_embedder_config():
    """Test that an invalid embedder configuration raises a ValueError."""
    with pytest.raises(ValueError, match="Input should be a valid dictionary"):
        Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            knowledge_sources=[StringKnowledgeSource(content="test content")],
            embedder_config="invalid"
        )

    with pytest.raises(ValueError, match="embedder_config must contain 'provider' key"):
        Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            knowledge_sources=[StringKnowledgeSource(content="test content")],
            embedder_config={"invalid": "config"}
        )

    with pytest.raises(ValueError, match="embedder_config must contain 'config' key"):
        Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            knowledge_sources=[StringKnowledgeSource(content="test content")],
            embedder_config={"provider": "custom"}
        )

    with pytest.raises(ValueError, match="Unsupported embedding provider"):
        Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            knowledge_sources=[StringKnowledgeSource(content="test content")],
            embedder_config={"provider": "invalid", "config": {}}
        )


def test_agent_knowledge_with_custom_embedder(mock_vector_db):
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        knowledge_sources=[StringKnowledgeSource(content="test content")],
        embedder_config={
            "provider": "custom",
            "config": {
                "embedder": MockEmbeddingFunction()
            }
        }
    )
    assert agent.knowledge is not None
    assert agent.knowledge.storage.embedder is not None


def test_agent_inherits_crew_embedder(mock_vector_db):
    test_agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory"
    )
    test_task = Task(
        description="test task",
        expected_output="test output",
        agent=test_agent
    )
    crew = Crew(
        agents=[test_agent],
        tasks=[test_task],
        process=Process.sequential,
        embedder_config={
            "provider": "custom",
            "config": {
                "embedder": MockEmbeddingFunction()
            }
        }
    )
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        knowledge_sources=[StringKnowledgeSource(content="test content")],
        crew=crew
    )
    assert agent.knowledge is not None
    assert agent.knowledge.storage.embedder is not None


def test_agent_knowledge_without_embedder_raises_error(mock_vector_db):
    with pytest.raises(ValueError, match="No embedder configuration provided"):
        agent = Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            knowledge_sources=[StringKnowledgeSource(content="test content")]
        )
```
```diff
@@ -45,7 +45,13 @@ def test_knowledge_included_in_planning(mock_chroma):
             StringKnowledgeSource(
                 content="AI systems require careful training and validation."
             )
-        ]
+        ],
+        embedder_config={
+            "provider": "openai",
+            "config": {
+                "api_key": "fake-api-key"
+            }
+        }
     )
 
     # Create a task for the agent
```