This commit is contained in:
Lorenze Jay
2025-01-27 13:31:00 -08:00
parent 1de204eff8
commit 9b88bcd97e
2 changed files with 14 additions and 10 deletions

View File

@@ -129,9 +129,9 @@ class Agent(BaseAgent):
default=None,
description="Embedder configuration for the agent.",
)
_knowledge: Optional[Knowledge] = PrivateAttr(
default=None,
)
# knowledge: Optional[Knowledge] = PrivateAttr(
# default=None,
# )
@model_validator(mode="after")
def post_init_setup(self):
@@ -162,7 +162,7 @@ class Agent(BaseAgent):
if isinstance(self.knowledge_sources, list) and all(
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
):
self._knowledge = Knowledge(
self.knowledge = Knowledge(
sources=self.knowledge_sources,
embedder=self.embedder,
collection_name=knowledge_agent_name,
@@ -225,8 +225,8 @@ class Agent(BaseAgent):
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
if self._knowledge:
agent_knowledge_snippets = self._knowledge.query([task.prompt()])
if self.knowledge:
agent_knowledge_snippets = self.knowledge.query([task.prompt()])
if agent_knowledge_snippets:
agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets

View File

@@ -211,8 +211,9 @@ class Crew(BaseModel):
default=None,
description="LLM used to handle chatting with the crew.",
)
_knowledge: Optional[Knowledge] = PrivateAttr(
knowledge: Optional[Knowledge] = Field(
default=None,
description="Knowledge for the crew.",
)
@field_validator("id", mode="before")
@@ -290,7 +291,7 @@ class Crew(BaseModel):
if isinstance(self.knowledge_sources, list) and all(
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
):
self._knowledge = Knowledge(
self.knowledge = Knowledge(
sources=self.knowledge_sources,
embedder_config=self.embedder,
collection_name="crew",
@@ -992,8 +993,8 @@ class Crew(BaseModel):
return result
def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]:
if self._knowledge:
return self._knowledge.query(query)
if self.knowledge:
return self.knowledge.query(query)
return None
def fetch_inputs(self) -> Set[str]:
@@ -1038,6 +1039,7 @@ class Crew(BaseModel):
"agents",
"tasks",
"knowledge_sources",
"knowledge",
}
cloned_agents = [agent.copy() for agent in self.agents]
@@ -1046,6 +1048,7 @@ class Crew(BaseModel):
cloned_tasks = []
existing_knowledge_sources = shallow_copy(self.knowledge_sources)
existing_knowledge = shallow_copy(self.knowledge)
for task in self.tasks:
cloned_task = task.copy(cloned_agents, task_mapping)
@@ -1071,6 +1074,7 @@ class Crew(BaseModel):
agents=cloned_agents,
tasks=cloned_tasks,
knowledge_sources=existing_knowledge_sources,
knowledge=existing_knowledge,
)
return copied_crew