fix breakage when cloning agent/crew using knowledge_sources and enable custom knowledge_storage (#1927)

* fix breakage when cloning agent/crew using knowledge_sources

* fixed typo

* better

* ensure use of other knowledge storage works

* fix copy and custom storage

* added tests

* normalized name

* updated cassette

* fix test

* remove fixture

* fixed test

* fix

* add fixture to this

* add fixture to this

* patch twice since

* fix again

* with fixtures

* better mocks

* fix

* simple

* try

* another

* hopefully fixes test

* hopefully fixes test

* this should fix it !

* WIP: test check with prints

* try this

* exclude knowledge

* fixes

* just drop clone for now

* rm print statements

* printing agent_copy

* checker

* linted

* cleanup

* better docs

---------

Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com>
This commit is contained in:
Lorenze Jay
2025-01-29 06:37:22 -08:00
committed by GitHub
parent c3e7a3ec19
commit a3ad2c1957
13 changed files with 626 additions and 92 deletions

View File

@@ -4,6 +4,7 @@ import re
import uuid
import warnings
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -210,8 +211,9 @@ class Crew(BaseModel):
default=None,
description="LLM used to handle chatting with the crew.",
)
_knowledge: Optional[Knowledge] = PrivateAttr(
knowledge: Optional[Knowledge] = Field(
default=None,
description="Knowledge for the crew.",
)
@field_validator("id", mode="before")
@@ -289,7 +291,7 @@ class Crew(BaseModel):
if isinstance(self.knowledge_sources, list) and all(
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
):
self._knowledge = Knowledge(
self.knowledge = Knowledge(
sources=self.knowledge_sources,
embedder_config=self.embedder,
collection_name="crew",
@@ -991,8 +993,8 @@ class Crew(BaseModel):
return result
def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]:
if self._knowledge:
return self._knowledge.query(query)
if self.knowledge:
return self.knowledge.query(query)
return None
def fetch_inputs(self) -> Set[str]:
@@ -1036,6 +1038,8 @@ class Crew(BaseModel):
"_telemetry",
"agents",
"tasks",
"knowledge_sources",
"knowledge",
}
cloned_agents = [agent.copy() for agent in self.agents]
@@ -1043,6 +1047,9 @@ class Crew(BaseModel):
task_mapping = {}
cloned_tasks = []
existing_knowledge_sources = shallow_copy(self.knowledge_sources)
existing_knowledge = shallow_copy(self.knowledge)
for task in self.tasks:
cloned_task = task.copy(cloned_agents, task_mapping)
cloned_tasks.append(cloned_task)
@@ -1062,7 +1069,13 @@ class Crew(BaseModel):
copied_data.pop("agents", None)
copied_data.pop("tasks", None)
copied_crew = Crew(**copied_data, agents=cloned_agents, tasks=cloned_tasks)
copied_crew = Crew(
**copied_data,
agents=cloned_agents,
tasks=cloned_tasks,
knowledge_sources=existing_knowledge_sources,
knowledge=existing_knowledge,
)
return copied_crew