From f3004ffb2baadb2dc3090642b0e293f0c061167c Mon Sep 17 00:00:00 2001 From: Lorenze Jay Date: Sun, 19 Jan 2025 15:58:01 -0800 Subject: [PATCH] fix breakage when cloning agent/crew using knowledge_sources --- src/crewai/agents/agent_builder/base_agent.py | 15 ++++++++++++++- src/crewai/crew.py | 10 +++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/crewai/agents/agent_builder/base_agent.py b/src/crewai/agents/agent_builder/base_agent.py index 207a1769a..87922834e 100644 --- a/src/crewai/agents/agent_builder/base_agent.py +++ b/src/crewai/agents/agent_builder/base_agent.py @@ -18,6 +18,7 @@ from pydantic_core import PydanticCustomError from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.agents.cache.cache_handler import CacheHandler from crewai.agents.tools_handler import ToolsHandler +from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.tools import BaseTool from crewai.tools.base_tool import Tool from crewai.utilities import I18N, Logger, RPMController @@ -130,6 +131,10 @@ class BaseAgent(ABC, BaseModel): max_tokens: Optional[int] = Field( default=None, description="Maximum number of tokens for the agent's execution." ) + knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field( + default=None, + description="Knowledge sources for the agent.", + ) @model_validator(mode="before") @classmethod @@ -256,13 +261,21 @@ class BaseAgent(ABC, BaseModel): "tools_handler", "cache_handler", "llm", + "knowledge_sources", } # Copy llm and clear callbacks existing_llm = shallow_copy(self.llm) copied_data = self.model_dump(exclude=exclude) copied_data = {k: v for k, v in copied_data.items() if v is not None} - copied_agent = type(self)(**copied_data, llm=existing_llm, tools=self.tools) + copied_agent = type(self)( + **copied_data, + llm=existing_llm, + tools=self.tools, + knowledge_sources=self.knowledge_sources + if hasattr(self, "knowledge_sources") + else None, + ) return copied_agent diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 38b96a0e0..880ffe758 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -1036,6 +1036,7 @@ class Crew(BaseModel): "_telemetry", "agents", "tasks", + "knowledge_source", } cloned_agents = [agent.copy() for agent in self.agents] @@ -1062,7 +1063,14 @@ class Crew(BaseModel): copied_data.pop("agents", None) copied_data.pop("tasks", None) - copied_crew = Crew(**copied_data, agents=cloned_agents, tasks=cloned_tasks) + copied_crew = Crew( + **copied_data, + agents=cloned_agents, + tasks=cloned_tasks, + knowledge_sources=self.knowledge_sources + if hasattr(self, "knowledge_sources") + else None, + ) return copied_crew