refactor: replace InstanceOf[T] with plain type annotations
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

* refactor: replace InstanceOf[T] with plain type annotations

InstanceOf[] is a Pydantic validation wrapper that adds runtime
isinstance checks. Plain type annotations are sufficient here since
the models already use arbitrary_types_allowed or the types are
BaseModel subclasses.

* refactor: convert BaseKnowledgeStorage to BaseModel

* fix: update tests for BaseKnowledgeStorage BaseModel conversion

* fix: correct embedder config structure in test
This commit is contained in:
Greyson LaLonde
2026-03-31 08:11:21 +08:00
committed by GitHub
parent ef79456968
commit dfc0f9a317
10 changed files with 58 additions and 35 deletions

View File

@@ -25,7 +25,6 @@ from pydantic import (
BaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
model_validator,
)
@@ -167,10 +166,10 @@ class Agent(BaseAgent):
default=True,
description="Use system prompt for the agent.",
)
llm: str | InstanceOf[BaseLLM] | None = Field(
llm: str | BaseLLM | None = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
function_calling_llm: str | BaseLLM | None = Field(
description="Language model that will run the agent.", default=None
)
system_template: str | None = Field(

View File

@@ -12,7 +12,6 @@ from pydantic import (
UUID4,
BaseModel,
Field,
InstanceOf,
PrivateAttr,
field_validator,
model_validator,
@@ -185,7 +184,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
default=None,
description="Knowledge sources for the agent.",
)
knowledge_storage: InstanceOf[BaseKnowledgeStorage] | None = Field(
knowledge_storage: BaseKnowledgeStorage | None = Field(
default=None,
description="Custom knowledge storage for the agent.",
)

View File

@@ -22,7 +22,6 @@ from pydantic import (
UUID4,
BaseModel,
Field,
InstanceOf,
Json,
PrivateAttr,
field_validator,
@@ -176,7 +175,7 @@ class Crew(FlowTrackable, BaseModel):
_rpm_controller: RPMController = PrivateAttr()
_logger: Logger = PrivateAttr()
_file_handler: FileHandler = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler)
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
_memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None)
_train: bool | None = PrivateAttr(default=False)
_train_iteration: int | None = PrivateAttr()
@@ -210,13 +209,13 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: str | InstanceOf[BaseLLM] | None = Field(
manager_llm: str | BaseLLM | None = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: BaseAgent | None = Field(
description="Custom agent that will be used as manager.", default=None
)
function_calling_llm: str | InstanceOf[LLM] | None = Field(
function_calling_llm: str | LLM | None = Field(
description="Language model that will run the agent.", default=None
)
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
@@ -267,7 +266,7 @@ class Crew(FlowTrackable, BaseModel):
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
planning_llm: str | BaseLLM | Any | None = Field(
default=None,
description=(
"Language model that will run the AgentPlanner if planning is True."
@@ -288,7 +287,7 @@ class Crew(FlowTrackable, BaseModel):
"knowledge object."
),
)
chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
chat_llm: str | BaseLLM | Any | None = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
@@ -1800,7 +1799,7 @@ class Crew(FlowTrackable, BaseModel):
def test(
self,
n_iterations: int,
eval_llm: str | InstanceOf[BaseLLM],
eval_llm: str | BaseLLM,
inputs: dict[str, Any] | None = None,
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations.

View File

@@ -3,12 +3,15 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, ConfigDict
if TYPE_CHECKING:
from crewai.rag.types import SearchResult
class BaseKnowledgeStorage(ABC):
class BaseKnowledgeStorage(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
"""Abstract base class for knowledge storage implementations."""
@abstractmethod

View File

@@ -3,6 +3,9 @@ import traceback
from typing import Any, cast
import warnings
from pydantic import Field, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
@@ -22,31 +25,32 @@ class KnowledgeStorage(BaseKnowledgeStorage):
search efficiency.
"""
def __init__(
self,
embedder: ProviderSpec
collection_name: str | None = None
embedder: (
ProviderSpec
| BaseEmbeddingsProvider[Any]
| type[BaseEmbeddingsProvider[Any]]
| None = None,
collection_name: str | None = None,
) -> None:
self.collection_name = collection_name
self._client: BaseClient | None = None
| None
) = Field(default=None, exclude=True)
_client: BaseClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _init_client(self) -> Self:
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
if embedder:
embedding_function = build_embedder(embedder) # type: ignore[arg-type]
if self.embedder:
embedding_function = build_embedder(self.embedder) # type: ignore[arg-type]
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
self._client = create_client(config)
return self
def _get_client(self) -> BaseClient:
"""Get the appropriate client - instance-specific or global."""

View File

@@ -22,7 +22,6 @@ from pydantic import (
UUID4,
BaseModel,
Field,
InstanceOf,
PrivateAttr,
field_validator,
model_validator,
@@ -204,7 +203,7 @@ class LiteAgent(FlowTrackable, BaseModel):
role: str = Field(description="Role of the agent")
goal: str = Field(description="Goal of the agent")
backstory: str = Field(description="Backstory of the agent")
llm: str | InstanceOf[BaseLLM] | Any | None = Field(
llm: str | BaseLLM | Any | None = Field(
default=None, description="Language model that will run the agent"
)
tools: list[BaseTool] = Field(

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field, InstanceOf
from pydantic import BaseModel, Field
from rich.box import HEAVY_EDGE
from rich.console import Console
from rich.table import Table
@@ -39,9 +39,9 @@ class CrewEvaluator:
def __init__(
self,
crew: Crew,
eval_llm: InstanceOf[BaseLLM] | str | None = None,
eval_llm: BaseLLM | str | None = None,
openai_model_name: str | None = None,
llm: InstanceOf[BaseLLM] | str | None = None,
llm: BaseLLM | str | None = None,
) -> None:
self.crew = crew
self.llm = eval_llm

View File

@@ -1692,9 +1692,27 @@ def test_agent_with_knowledge_sources_works_with_copy():
) as mock_knowledge_storage:
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
mock_knowledge_storage_instance.__class__ = BaseKnowledgeStorage
agent.knowledge_storage = mock_knowledge_storage_instance
class _StubStorage(BaseKnowledgeStorage):
def search(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
return []
async def asearch(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
return []
def save(self, documents):
pass
async def asave(self, documents):
pass
def reset(self):
pass
async def areset(self):
pass
mock_knowledge_storage.return_value = _StubStorage()
agent.knowledge_storage = _StubStorage()
agent_copy = agent.copy()

View File

@@ -132,12 +132,12 @@ def test_embedding_configuration_flow(
embedder_config = {
"provider": "sentence-transformer",
"model_name": "all-MiniLM-L6-v2",
"config": {"model_name": "all-MiniLM-L6-v2"},
}
KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
storage = KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
mock_get_embedding.assert_called_once_with(embedder_config)
mock_get_embedding.assert_called_once_with(storage.embedder)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")

View File

@@ -3,6 +3,8 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
KnowledgeStorage,
)
@@ -59,7 +61,7 @@ def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock)
"Unsupported provider: invalid_provider"
)
with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
with pytest.raises(ValidationError):
KnowledgeStorage(
embedder={"provider": "invalid_provider"},
collection_name="invalid_embedding_test",