improvements from review

This commit is contained in:
Lorenze Jay
2024-11-20 13:32:00 -08:00
parent 3c4504bd4f
commit 44ab749fda
7 changed files with 25 additions and 43 deletions

View File

@@ -121,7 +121,6 @@ class Agent(BaseAgent):
default="safe",
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
)
_knowledge: Optional[Knowledge] = PrivateAttr(default=None)
@model_validator(mode="after")
def post_init_setup(self):
@@ -230,12 +229,6 @@ class Agent(BaseAgent):
if self.allow_code_execution:
self._validate_docker_installation()
# Initialize the Knowledge object if knowledge_sources are provided
if self.crew and self.crew.knowledge_store:
self._knowledge = self.crew.knowledge_store
else:
self._knowledge = None
return self
def _setup_agent_executor(self):
@@ -282,19 +275,16 @@ class Agent(BaseAgent):
task_prompt += self.i18n.slice("memory").format(memory=memory)
# Integrate the knowledge base
if self.crew and self.crew.knowledge_store:
knowledge_snippets: List[Dict[str, Any]] = self.crew.knowledge_store.query(
[task.prompt()]
)
if knowledge_snippets:
valid_snippets = [
result["context"]
for result in knowledge_snippets
if result and result.get("context")
]
if valid_snippets:
formatted_knowledge = "\n".join(valid_snippets)
task_prompt += f"\n\nAdditional Information:\n{formatted_knowledge}"
if self.crew and self.crew.knowledge:
knowledge_snippets = self.crew.knowledge.query([task.prompt()])
valid_snippets = [
result["context"]
for result in knowledge_snippets
if result and result.get("context")
]
if valid_snippets:
formatted_knowledge = "\n".join(valid_snippets)
task_prompt += f"\n\nAdditional Information:\n{formatted_knowledge}"
tools = tools or self.tools or []
self.create_agent_executor(tools=tools, task=task)

View File

@@ -195,18 +195,10 @@ class Crew(BaseModel):
default=[],
description="List of execution logs for tasks",
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
default=None,
description="Knowledge sources for the agent.",
knowledge: Optional[Dict[str, Any]] = Field(
default=None, description="Knowledge for the crew. Add knowledge sources to the knowledge object."
)
knowledge_store: Optional[Knowledge] = Field(
default=None, description="Knowledge Source for the crew."
)
knowledge: Optional[bool] = Field(
default=False,
description="Whether the crew should use knowledge to store memories of it's execution",
)
@field_validator("id", mode="before")
@classmethod
@@ -284,9 +276,7 @@ class Crew(BaseModel):
@model_validator(mode="after")
def create_crew_knowledge(self) -> "Crew":
if self.knowledge:
self.knowledge_store = Knowledge(
sources=self.knowledge_sources or [], embedder_config=self.embedder
)
self.knowledge = Knowledge(**self.knowledge)
return self
@model_validator(mode="after")

View File

@@ -18,10 +18,7 @@ class Knowledge(BaseModel):
def __init__(self, embedder_config: Optional[Dict[str, Any]] = None, **data):
super().__init__(**data)
if embedder_config:
self.storage = KnowledgeStorage(embedder_config=embedder_config)
else:
self.storage = KnowledgeStorage()
self.storage = KnowledgeStorage(embedder_config=embedder_config or None)
try:
for source in self.sources:

View File

@@ -14,9 +14,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource):
file_path: Union[Path, List[Path]] = Field(...)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
metadata: Dict[str, Any] = Field(default_factory=dict)
def model_post_init(self, context):
def model_post_init(self, _):
"""Post-initialization method to load content."""
self.content = self.load_content()

View File

@@ -10,7 +10,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
content: str = Field(...)
def model_post_init(self, context):
def model_post_init(self, _):
"""Post-initialization method to validate content."""
self.load_content()

View File

@@ -9,6 +9,7 @@ from typing import Optional, List
from typing import Dict, Any
from crewai.utilities import EmbeddingConfigurator
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
import hashlib
@contextlib.contextmanager
@@ -95,10 +96,14 @@ class KnowledgeStorage(BaseKnowledgeStorage):
if self.collection:
metadatas = [metadata] if isinstance(metadata, dict) else metadata
self.collection.add(
ids = [
hashlib.sha256(doc.encode("utf-8")).hexdigest() for doc in documents
]
self.collection.upsert(
documents=documents,
metadatas=metadatas,
ids=[str(uuid.uuid4()) for _ in range(len(documents))],
ids=ids,
)
else:
raise Exception("Collection not initialized")

View File

@@ -38,6 +38,7 @@ def mock_crew_factory():
crew = MockCrew()
crew.name = name
crew.knowledge = None
task_output = TaskOutput(
description="Test task", raw="Task output", agent="Test Agent"