fix copy and custom storage

This commit is contained in:
Lorenze Jay
2025-01-24 12:17:31 -08:00
parent 591c4a511b
commit 71246e9de1
3 changed files with 8 additions and 6 deletions

View File

@@ -165,7 +165,7 @@ class Agent(BaseAgent):
sources=self.knowledge_sources, sources=self.knowledge_sources,
embedder_config=self.embedder_config, embedder_config=self.embedder_config,
collection_name=knowledge_agent_name, collection_name=knowledge_agent_name,
storage=self.custom_knowledge_storage or None, storage=self.knowledge_storage or None,
) )
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}") raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")

View File

@@ -135,7 +135,7 @@ class BaseAgent(ABC, BaseModel):
default=None, default=None,
description="Knowledge sources for the agent.", description="Knowledge sources for the agent.",
) )
custom_knowledge_storage: Optional[Any] = Field( knowledge_storage: Optional[Any] = Field(
default=None, default=None,
description="Custom knowledge storage for the agent.", description="Custom knowledge storage for the agent.",
) )
@@ -270,13 +270,14 @@ class BaseAgent(ABC, BaseModel):
# Copy llm and clear callbacks # Copy llm and clear callbacks
existing_llm = shallow_copy(self.llm) existing_llm = shallow_copy(self.llm)
existing_knowledge_sources = shallow_copy(self.knowledge_sources)
copied_data = self.model_dump(exclude=exclude) copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None} copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = type(self)( copied_agent = type(self)(
**copied_data, **copied_data,
llm=existing_llm, llm=existing_llm,
tools=self.tools, tools=self.tools,
knowledge_sources=getattr(self, "knowledge_sources", None), knowledge_sources=existing_knowledge_sources,
) )
return copied_agent return copied_agent

View File

@@ -4,6 +4,7 @@ import re
import uuid import uuid
import warnings import warnings
from concurrent.futures import Future from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5 from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -1044,6 +1045,8 @@ class Crew(BaseModel):
task_mapping = {} task_mapping = {}
cloned_tasks = [] cloned_tasks = []
knowledge_sources_copied = shallow_copy(self.knowledge_sources)
for task in self.tasks: for task in self.tasks:
cloned_task = task.copy(cloned_agents, task_mapping) cloned_task = task.copy(cloned_agents, task_mapping)
cloned_tasks.append(cloned_task) cloned_tasks.append(cloned_task)
@@ -1067,9 +1070,7 @@ class Crew(BaseModel):
**copied_data, **copied_data,
agents=cloned_agents, agents=cloned_agents,
tasks=cloned_tasks, tasks=cloned_tasks,
knowledge_sources=self.knowledge_sources knowledge_sources=knowledge_sources_copied,
if hasattr(self, "knowledge_sources")
else None,
) )
return copied_crew return copied_crew