From 707c50b83329ea8e07df80ad6d717f846161aac5 Mon Sep 17 00:00:00 2001 From: Lorenze Jay Date: Tue, 26 Nov 2024 11:52:57 -0800 Subject: [PATCH] added from suggestions --- src/crewai/agent.py | 37 ++++++++++--------- src/crewai/crew.py | 20 +++++----- src/crewai/knowledge/knowledge.py | 29 +++++++++++++-- .../knowledge/source/base_knowledge_source.py | 2 +- .../source/string_knowledge_source.py | 2 +- .../knowledge/storage/knowledge_storage.py | 17 ++++++--- 6 files changed, 69 insertions(+), 38 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 4b163e05a..564b59372 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -11,6 +11,7 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.cli.constants import ENV_VARS from crewai.llm import LLM from crewai.knowledge.knowledge import Knowledge +from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.tools import BaseTool from crewai.tools.agent_tools.agent_tools import AgentTools @@ -121,10 +122,14 @@ class Agent(BaseAgent): default="safe", description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).", ) - knowledge: Optional[Dict[str, Any]] = Field( + knowledge: Optional[Union[List[BaseKnowledgeSource], Knowledge]] = Field( default=None, description="Knowledge for the agent. Add knowledge sources to the knowledge object.", ) + embedder_config: Optional[Dict[str, Any]] = Field( + default=None, + description="Embedder configuration for the agent.", + ) @model_validator(mode="after") def post_init_setup(self): @@ -245,21 +250,15 @@ class Agent(BaseAgent): try: if self.knowledge: knowledge_agent_name = f"{self.role.replace(' ', '_')}" - if isinstance(self.knowledge, dict): - knowledge_data = self.knowledge.copy() - knowledge_data["store_dir"] = knowledge_agent_name - self.knowledge = Knowledge(**knowledge_data) - self.knowledge.storage.initialize_knowledge_storage() - try: - for source in self.knowledge.sources: - source.storage = self.knowledge.storage - source.add() - except Exception as e: - self._logger.log( - "warning", - f"Failed to init knowledge: {knowledge_agent_name} {e}", - color="yellow", - ) + print("knowledge_agent_name", knowledge_agent_name) + if isinstance(self.knowledge, list) and all( + isinstance(k, BaseKnowledgeSource) for k in self.knowledge + ): + self.knowledge = Knowledge( + sources=self.knowledge, + embedder_config=self.embedder_config, + collection_name=knowledge_agent_name, + ) except (TypeError, ValueError) as e: raise ValueError(f"Invalid Knowledge Configuration: {str(e)}") @@ -303,7 +302,7 @@ class Agent(BaseAgent): if self.knowledge and isinstance(self.knowledge, Knowledge): agent_knowledge_snippets = self.knowledge.query([task.prompt()]) - agent_knowledge_context = self._extract_knowledge_context( + agent_knowledge_context = self.knowledge.extract_knowledge_context( agent_knowledge_snippets ) if agent_knowledge_context: @@ -312,7 +311,9 @@ class Agent(BaseAgent): if self.crew and self.crew.knowledge: knowledge_snippets = self.crew.knowledge.query([task.prompt()]) - crew_knowledge_context = self._extract_knowledge_context(knowledge_snippets) + crew_knowledge_context = self.crew.knowledge.extract_knowledge_context( + knowledge_snippets + ) if crew_knowledge_context: task_prompt += crew_knowledge_context diff --git a/src/crewai/crew.py b/src/crewai/crew.py index c06bd611c..2e3da769f 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -28,6 +28,7 @@ from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.long_term.long_term_memory import LongTermMemory from crewai.memory.short_term.short_term_memory import ShortTermMemory from crewai.knowledge.knowledge import Knowledge +from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.memory.user.user_memory import UserMemory from crewai.process import Process from crewai.task import Task @@ -202,7 +203,7 @@ class Crew(BaseModel): default=[], description="List of execution logs for tasks", ) - knowledge: Optional[Dict[str, Any]] = Field( + knowledge: Optional[Union[List[BaseKnowledgeSource], Knowledge]] = Field( default=None, description="Knowledge for the crew. Add knowledge sources to the knowledge object.", ) @@ -285,16 +286,15 @@ class Crew(BaseModel): """Create the knowledge for the crew.""" if self.knowledge: try: - self.knowledge = ( - Knowledge(**self.knowledge, store_dir="crew") - if isinstance(self.knowledge, dict) - else self.knowledge - ) - self.knowledge.storage.initialize_knowledge_storage() + if isinstance(self.knowledge, list) and all( + isinstance(k, BaseKnowledgeSource) for k in self.knowledge + ): + self.knowledge = Knowledge( + sources=self.knowledge, + embedder_config=self.embedder, + collection_name="crew", + ) - for source in self.knowledge.sources: - source.storage = self.knowledge.storage - source.add() except Exception as e: self._logger.log( "warning", f"Failed to init knowledge: {e}", color="yellow" diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py index c5911898a..cd5b0bfba 100644 --- a/src/crewai/knowledge/knowledge.py +++ b/src/crewai/knowledge/knowledge.py @@ -23,11 +23,12 @@ class Knowledge(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage) embedder_config: Optional[Dict[str, Any]] = None - store_dir: Optional[str] = None + collection_name: Optional[str] = None def __init__( self, - store_dir: str, + collection_name: str, + sources: List[BaseKnowledgeSource], embedder_config: Optional[Dict[str, Any]] = None, storage: Optional[KnowledgeStorage] = None, **data, @@ -37,8 +38,13 @@ class Knowledge(BaseModel): self.storage = storage else: self.storage = KnowledgeStorage( - embedder_config=embedder_config, store_dir=store_dir + embedder_config=embedder_config, collection_name=collection_name ) + self.sources = sources + self.storage.initialize_knowledge_storage() + for source in sources: + source.storage = self.storage + source.add() def query( self, query: List[str], limit: int = 3, preference: Optional[str] = None @@ -55,3 +61,20 @@ class Knowledge(BaseModel): score_threshold=DEFAULT_SCORE_THRESHOLD, ) return results + + def extract_knowledge_context( + self, knowledge_snippets: List[Dict[str, Any]] + ) -> str: + """Extract knowledge from the task prompt.""" + valid_snippets = [ + result["context"] + for result in knowledge_snippets + if result and result.get("context") + ] + snippet = "\n".join(valid_snippets) + return f"Additional Information: {snippet}" if valid_snippets else "" + + def _add_sources(self): + for source in self.sources: + source.storage = self.storage + source.add() diff --git a/src/crewai/knowledge/source/base_knowledge_source.py b/src/crewai/knowledge/source/base_knowledge_source.py index d62f08709..6be76ca40 100644 --- a/src/crewai/knowledge/source/base_knowledge_source.py +++ b/src/crewai/knowledge/source/base_knowledge_source.py @@ -18,7 +18,7 @@ class BaseKnowledgeSource(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage) metadata: Dict[str, Any] = Field(default_factory=dict) - store_dir: Optional[str] = Field(default=None) + collection_name: Optional[str] = Field(default=None) @abstractmethod def load_content(self) -> Dict[Any, str]: diff --git a/src/crewai/knowledge/source/string_knowledge_source.py b/src/crewai/knowledge/source/string_knowledge_source.py index 7e33d7fe4..7336fd3ea 100644 --- a/src/crewai/knowledge/source/string_knowledge_source.py +++ b/src/crewai/knowledge/source/string_knowledge_source.py @@ -9,7 +9,7 @@ class StringKnowledgeSource(BaseKnowledgeSource): """A knowledge source that stores and queries plain text content using embeddings.""" content: str = Field(...) - store_dir: Optional[str] = Field(default=None) + collection_name: Optional[str] = Field(default=None) def model_post_init(self, _): """Post-initialization method to validate content.""" diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index e23ddd0f2..58024895d 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -35,16 +35,16 @@ class KnowledgeStorage(BaseKnowledgeStorage): """ collection: Optional[chromadb.Collection] = None - store_dir: Optional[str] = "knowledge" + collection_name: Optional[str] = "knowledge" app: Optional[chromadb.PersistentClient] = None def __init__( self, embedder_config: Optional[Dict[str, Any]] = None, - store_dir: Optional[str] = None, + collection_name: Optional[str] = None, ): self.embedder_config = embedder_config - self.store_dir = store_dir + self.collection_name = collection_name def search( self, @@ -85,9 +85,16 @@ class KnowledgeStorage(BaseKnowledgeStorage): try: collection_name = ( - f"knowledge_{self.store_dir}" if self.store_dir else "knowledge" + f"knowledge_{self.collection_name}" + if self.collection_name + else "knowledge" ) - self.collection = self.app.get_or_create_collection(name=collection_name) + if self.app: + self.collection = self.app.get_or_create_collection( + name=collection_name + ) + else: + raise Exception("Vector Database Client not initialized") except Exception: raise Exception("Failed to create or get collection")