This commit is contained in:
Lorenze Jay
2025-01-27 13:31:00 -08:00
parent 1de204eff8
commit 9b88bcd97e
2 changed files with 14 additions and 10 deletions

View File

@@ -129,9 +129,9 @@ class Agent(BaseAgent):
default=None, default=None,
description="Embedder configuration for the agent.", description="Embedder configuration for the agent.",
) )
_knowledge: Optional[Knowledge] = PrivateAttr( # knowledge: Optional[Knowledge] = PrivateAttr(
default=None, # default=None,
) # )
@model_validator(mode="after") @model_validator(mode="after")
def post_init_setup(self): def post_init_setup(self):
@@ -162,7 +162,7 @@ class Agent(BaseAgent):
if isinstance(self.knowledge_sources, list) and all( if isinstance(self.knowledge_sources, list) and all(
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
): ):
self._knowledge = Knowledge( self.knowledge = Knowledge(
sources=self.knowledge_sources, sources=self.knowledge_sources,
embedder=self.embedder, embedder=self.embedder,
collection_name=knowledge_agent_name, collection_name=knowledge_agent_name,
@@ -225,8 +225,8 @@ class Agent(BaseAgent):
if memory.strip() != "": if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory) task_prompt += self.i18n.slice("memory").format(memory=memory)
if self._knowledge: if self.knowledge:
agent_knowledge_snippets = self._knowledge.query([task.prompt()]) agent_knowledge_snippets = self.knowledge.query([task.prompt()])
if agent_knowledge_snippets: if agent_knowledge_snippets:
agent_knowledge_context = extract_knowledge_context( agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets agent_knowledge_snippets

View File

@@ -211,8 +211,9 @@ class Crew(BaseModel):
default=None, default=None,
description="LLM used to handle chatting with the crew.", description="LLM used to handle chatting with the crew.",
) )
_knowledge: Optional[Knowledge] = PrivateAttr( knowledge: Optional[Knowledge] = Field(
default=None, default=None,
description="Knowledge for the crew.",
) )
@field_validator("id", mode="before") @field_validator("id", mode="before")
@@ -290,7 +291,7 @@ class Crew(BaseModel):
if isinstance(self.knowledge_sources, list) and all( if isinstance(self.knowledge_sources, list) and all(
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
): ):
self._knowledge = Knowledge( self.knowledge = Knowledge(
sources=self.knowledge_sources, sources=self.knowledge_sources,
embedder_config=self.embedder, embedder_config=self.embedder,
collection_name="crew", collection_name="crew",
@@ -992,8 +993,8 @@ class Crew(BaseModel):
return result return result
def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]: def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]:
if self._knowledge: if self.knowledge:
return self._knowledge.query(query) return self.knowledge.query(query)
return None return None
def fetch_inputs(self) -> Set[str]: def fetch_inputs(self) -> Set[str]:
@@ -1038,6 +1039,7 @@ class Crew(BaseModel):
"agents", "agents",
"tasks", "tasks",
"knowledge_sources", "knowledge_sources",
"knowledge",
} }
cloned_agents = [agent.copy() for agent in self.agents] cloned_agents = [agent.copy() for agent in self.agents]
@@ -1046,6 +1048,7 @@ class Crew(BaseModel):
cloned_tasks = [] cloned_tasks = []
existing_knowledge_sources = shallow_copy(self.knowledge_sources) existing_knowledge_sources = shallow_copy(self.knowledge_sources)
existing_knowledge = shallow_copy(self.knowledge)
for task in self.tasks: for task in self.tasks:
cloned_task = task.copy(cloned_agents, task_mapping) cloned_task = task.copy(cloned_agents, task_mapping)
@@ -1071,6 +1074,7 @@ class Crew(BaseModel):
agents=cloned_agents, agents=cloned_agents,
tasks=cloned_tasks, tasks=cloned_tasks,
knowledge_sources=existing_knowledge_sources, knowledge_sources=existing_knowledge_sources,
knowledge=existing_knowledge,
) )
return copied_crew return copied_crew