From 2b438baad4891852957c235e2ab888766d143d03 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Fri, 14 Feb 2025 15:11:58 -0500 Subject: [PATCH] Fix issues --- src/crewai/agent.py | 2 -- src/crewai/agents/agent_builder/base_agent.py | 5 ++-- src/crewai/crew.py | 23 +++++++++++++++++++ tests/crew_test.py | 16 ++++--------- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 0f107c307..f0a8c5ffa 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -175,10 +175,8 @@ class Agent(BaseAgent): Returns: Output of the agent """ - # The RPM controller is now managed by the Crew, so no need to set it here. if self.tools_handler: self.tools_handler.last_used_tool = {} # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalli - task_prompt = task.prompt() # If the task requires output in JSON or Pydantic format, diff --git a/src/crewai/agents/agent_builder/base_agent.py b/src/crewai/agents/agent_builder/base_agent.py index 9910031d8..cf4fca304 100644 --- a/src/crewai/agents/agent_builder/base_agent.py +++ b/src/crewai/agents/agent_builder/base_agent.py @@ -380,5 +380,6 @@ class BaseAgent(ABC, BaseModel): """ if self.cache: self.set_cache_handler(cache_handler) - if self.max_rpm: - self.set_rpm_controller() + # Use the injected RPM controller rather than auto-creating one + if rpm_controller: + self.set_rpm_controller(rpm_controller) diff --git a/src/crewai/crew.py b/src/crewai/crew.py index b5a83d078..733f8f4c6 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -255,6 +255,29 @@ class Crew(BaseModel): self._telemetry.set_tracer() return self + @model_validator(mode="after") + def initialize_dependencies(self) -> "Crew": + # Create a cache handler if caching is enabled + if self.cache: + self._cache_handler = CacheHandler() + else: + self._cache_handler = None + + # Create the Crew-level RPM controller if a max RPM is specified + if self.max_rpm is not None: + self._rpm_controller = RPMController( + max_rpm=self.max_rpm, logger=Logger(verbose=self.verbose) + ) + else: + self._rpm_controller = None + + # Now inject these external dependencies into each agent + for agent in self.agents: + agent.crew = self # ensure the agent's crew reference is set + agent.configure_executor(self._cache_handler, self._rpm_controller) + + return self + @model_validator(mode="after") def create_crew_memory(self) -> "Crew": """Set private attributes.""" diff --git a/tests/crew_test.py b/tests/crew_test.py index 0539ea347..8b74c75b5 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -9,7 +9,6 @@ from unittest.mock import MagicMock, patch import instructor import pydantic_core import pytest - from crewai.agent import Agent from crewai.agents.cache import CacheHandler from crewai.crew import Crew @@ -541,9 +540,8 @@ def test_crew_with_delegating_agents(): def test_crew_with_delegating_agents_should_not_override_task_tools(): from typing import Type - from pydantic import BaseModel, Field - from crewai.tools import BaseTool + from pydantic import BaseModel, Field class TestToolInput(BaseModel): """Input schema for TestTool.""" @@ -603,9 +601,8 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(): def test_crew_with_delegating_agents_should_not_override_agent_tools(): from typing import Type - from pydantic import BaseModel, Field - from crewai.tools import BaseTool + from pydantic import BaseModel, Field class TestToolInput(BaseModel): """Input schema for TestTool.""" @@ -667,9 +664,8 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(): def test_task_tools_override_agent_tools(): from typing import Type - from pydantic import BaseModel, Field - from crewai.tools import BaseTool + from pydantic import BaseModel, Field class TestToolInput(BaseModel): """Input schema for TestTool.""" @@ -725,9 +721,8 @@ def test_task_tools_override_agent_tools_with_allow_delegation(): """ from typing import Type - from pydantic import BaseModel, Field - from crewai.tools import BaseTool + from pydantic import BaseModel, Field class TestToolInput(BaseModel): query: str = Field(..., description="Query to process") @@ -3429,11 +3424,10 @@ def test_task_tools_preserve_code_execution_tools(): """ from typing import Type + from crewai.tools import BaseTool from crewai_tools import CodeInterpreterTool from pydantic import BaseModel, Field - from crewai.tools import BaseTool - class TestToolInput(BaseModel): """Input schema for TestTool."""