diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 21b586cd7..8c31dd139 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -25,7 +25,6 @@ from pydantic import ( BaseModel, ConfigDict, Field, - InstanceOf, PrivateAttr, model_validator, ) @@ -167,10 +166,10 @@ class Agent(BaseAgent): default=True, description="Use system prompt for the agent.", ) - llm: str | InstanceOf[BaseLLM] | None = Field( + llm: str | BaseLLM | None = Field( description="Language model that will run the agent.", default=None ) - function_calling_llm: str | InstanceOf[BaseLLM] | None = Field( + function_calling_llm: str | BaseLLM | None = Field( description="Language model that will run the agent.", default=None ) system_template: str | None = Field( diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index 9949343e2..ce5682266 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -12,7 +12,6 @@ from pydantic import ( UUID4, BaseModel, Field, - InstanceOf, PrivateAttr, field_validator, model_validator, @@ -185,7 +184,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default=None, description="Knowledge sources for the agent.", ) - knowledge_storage: InstanceOf[BaseKnowledgeStorage] | None = Field( + knowledge_storage: BaseKnowledgeStorage | None = Field( default=None, description="Custom knowledge storage for the agent.", ) diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 00fbae78f..00107b063 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -22,7 +22,6 @@ from pydantic import ( UUID4, BaseModel, Field, - InstanceOf, Json, PrivateAttr, field_validator, @@ -176,7 +175,7 @@ class Crew(FlowTrackable, BaseModel): _rpm_controller: RPMController = PrivateAttr() _logger: Logger = 
PrivateAttr() _file_handler: FileHandler = PrivateAttr() - _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler) + _cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler) _memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None) _train: bool | None = PrivateAttr(default=False) _train_iteration: int | None = PrivateAttr() @@ -210,13 +209,13 @@ class Crew(FlowTrackable, BaseModel): default=None, description="Metrics for the LLM usage during all tasks execution.", ) - manager_llm: str | InstanceOf[BaseLLM] | None = Field( + manager_llm: str | BaseLLM | None = Field( description="Language model that will run the agent.", default=None ) manager_agent: BaseAgent | None = Field( description="Custom agent that will be used as manager.", default=None ) - function_calling_llm: str | InstanceOf[LLM] | None = Field( + function_calling_llm: str | LLM | None = Field( description="Language model that will run the agent.", default=None ) config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None) @@ -267,7 +266,7 @@ class Crew(FlowTrackable, BaseModel): default=False, description="Plan the crew execution and add the plan to the crew.", ) - planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field( + planning_llm: str | BaseLLM | Any | None = Field( default=None, description=( "Language model that will run the AgentPlanner if planning is True." @@ -288,7 +287,7 @@ class Crew(FlowTrackable, BaseModel): "knowledge object." ), ) - chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field( + chat_llm: str | BaseLLM | Any | None = Field( default=None, description="LLM used to handle chatting with the crew.", ) @@ -1800,7 +1799,7 @@ class Crew(FlowTrackable, BaseModel): def test( self, n_iterations: int, - eval_llm: str | InstanceOf[BaseLLM], + eval_llm: str | BaseLLM, inputs: dict[str, Any] | None = None, ) -> None: """Test and evaluate the Crew with the given inputs for n iterations. 
diff --git a/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py b/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py index e8a2054f7..ea8aff734 100644 --- a/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py +++ b/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py @@ -3,12 +3,15 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from pydantic import BaseModel, ConfigDict + if TYPE_CHECKING: from crewai.rag.types import SearchResult -class BaseKnowledgeStorage(ABC): +class BaseKnowledgeStorage(BaseModel, ABC): """Abstract base class for knowledge storage implementations.""" + model_config = ConfigDict(arbitrary_types_allowed=True) @abstractmethod diff --git a/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py b/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py index cfcbca25a..3c9615946 100644 --- a/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py +++ b/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py @@ -3,6 +3,9 @@ import traceback from typing import Any, cast import warnings +from pydantic import Field, PrivateAttr, model_validator +from typing_extensions import Self + from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.rag.chromadb.config import ChromaDBConfig from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper @@ -22,31 +25,32 @@ class KnowledgeStorage(BaseKnowledgeStorage): search efficiency. 
""" - def __init__( - self, - embedder: ProviderSpec + collection_name: str | None = None + embedder: ( + ProviderSpec | BaseEmbeddingsProvider[Any] | type[BaseEmbeddingsProvider[Any]] - | None = None, - collection_name: str | None = None, - ) -> None: - self.collection_name = collection_name - self._client: BaseClient | None = None + | None + ) = Field(default=None, exclude=True) + _client: BaseClient | None = PrivateAttr(default=None) + @model_validator(mode="after") + def _init_client(self) -> Self: warnings.filterwarnings( "ignore", message=r".*'model_fields'.*is deprecated.*", module=r"^chromadb(\.|$)", ) - if embedder: - embedding_function = build_embedder(embedder) # type: ignore[arg-type] + if self.embedder: + embedding_function = build_embedder(self.embedder) # type: ignore[arg-type] config = ChromaDBConfig( embedding_function=cast( ChromaEmbeddingFunctionWrapper, embedding_function ) ) self._client = create_client(config) + return self def _get_client(self) -> BaseClient: """Get the appropriate client - instance-specific or global.""" diff --git a/lib/crewai/src/crewai/lite_agent.py b/lib/crewai/src/crewai/lite_agent.py index 4e7d22280..bbb464010 100644 --- a/lib/crewai/src/crewai/lite_agent.py +++ b/lib/crewai/src/crewai/lite_agent.py @@ -22,7 +22,6 @@ from pydantic import ( UUID4, BaseModel, Field, - InstanceOf, PrivateAttr, field_validator, model_validator, @@ -204,7 +203,7 @@ class LiteAgent(FlowTrackable, BaseModel): role: str = Field(description="Role of the agent") goal: str = Field(description="Goal of the agent") backstory: str = Field(description="Backstory of the agent") - llm: str | InstanceOf[BaseLLM] | Any | None = Field( + llm: str | BaseLLM | Any | None = Field( default=None, description="Language model that will run the agent" ) tools: list[BaseTool] = Field( diff --git a/lib/crewai/src/crewai/utilities/evaluators/crew_evaluator_handler.py b/lib/crewai/src/crewai/utilities/evaluators/crew_evaluator_handler.py index 32b847d73..9dbfbcb86 
100644 --- a/lib/crewai/src/crewai/utilities/evaluators/crew_evaluator_handler.py +++ b/lib/crewai/src/crewai/utilities/evaluators/crew_evaluator_handler.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import defaultdict from typing import TYPE_CHECKING, Any -from pydantic import BaseModel, Field, InstanceOf +from pydantic import BaseModel, Field from rich.box import HEAVY_EDGE from rich.console import Console from rich.table import Table @@ -39,9 +39,9 @@ class CrewEvaluator: def __init__( self, crew: Crew, - eval_llm: InstanceOf[BaseLLM] | str | None = None, + eval_llm: BaseLLM | str | None = None, openai_model_name: str | None = None, - llm: InstanceOf[BaseLLM] | str | None = None, + llm: BaseLLM | str | None = None, ) -> None: self.crew = crew self.llm = eval_llm diff --git a/lib/crewai/tests/agents/test_agent.py b/lib/crewai/tests/agents/test_agent.py index d865ec541..7706f9ade 100644 --- a/lib/crewai/tests/agents/test_agent.py +++ b/lib/crewai/tests/agents/test_agent.py @@ -1692,9 +1692,27 @@ def test_agent_with_knowledge_sources_works_with_copy(): ) as mock_knowledge_storage: from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage - mock_knowledge_storage_instance = mock_knowledge_storage.return_value - mock_knowledge_storage_instance.__class__ = BaseKnowledgeStorage - agent.knowledge_storage = mock_knowledge_storage_instance + class _StubStorage(BaseKnowledgeStorage): + def search(self, query, limit=5, metadata_filter=None, score_threshold=0.6): + return [] + + async def asearch(self, query, limit=5, metadata_filter=None, score_threshold=0.6): + return [] + + def save(self, documents): + pass + + async def asave(self, documents): + pass + + def reset(self): + pass + + async def areset(self): + pass + + mock_knowledge_storage.return_value = _StubStorage() + agent.knowledge_storage = _StubStorage() agent_copy = agent.copy() diff --git a/lib/crewai/tests/knowledge/test_knowledge_storage_integration.py 
b/lib/crewai/tests/knowledge/test_knowledge_storage_integration.py index a58dcb2fc..5a228cde4 100644 --- a/lib/crewai/tests/knowledge/test_knowledge_storage_integration.py +++ b/lib/crewai/tests/knowledge/test_knowledge_storage_integration.py @@ -132,12 +132,12 @@ def test_embedding_configuration_flow( embedder_config = { "provider": "sentence-transformer", - "model_name": "all-MiniLM-L6-v2", + "config": {"model_name": "all-MiniLM-L6-v2"}, } - KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test") + storage = KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test") - mock_get_embedding.assert_called_once_with(embedder_config) + mock_get_embedding.assert_called_once_with(storage.embedder) @patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") diff --git a/lib/crewai/tests/rag/test_error_handling.py b/lib/crewai/tests/rag/test_error_handling.py index fab568e14..6a2962806 100644 --- a/lib/crewai/tests/rag/test_error_handling.py +++ b/lib/crewai/tests/rag/test_error_handling.py @@ -3,6 +3,8 @@ from unittest.mock import MagicMock, patch import pytest +from pydantic import ValidationError + from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped] KnowledgeStorage, ) @@ -59,7 +61,7 @@ def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock) "Unsupported provider: invalid_provider" ) - with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"): + with pytest.raises(ValidationError, match="Unsupported provider: invalid_provider"): KnowledgeStorage( embedder={"provider": "invalid_provider"}, collection_name="invalid_embedding_test",