diff --git a/src/crewai_tools/tools/weaviate_tool/vector_search.py b/src/crewai_tools/tools/weaviate_tool/vector_search.py index a9c7ce519..fc5641009 100644 --- a/src/crewai_tools/tools/weaviate_tool/vector_search.py +++ b/src/crewai_tools/tools/weaviate_tool/vector_search.py @@ -1,10 +1,12 @@ -from typing import Any, Type, Optional -import os import json +import os +from typing import Any, Optional, Type + try: import weaviate from weaviate.classes.config import Configure, Vectorizers from weaviate.classes.init import Auth + WEAVIATE_AVAILABLE = True except ImportError: WEAVIATE_AVAILABLE = False @@ -14,6 +16,7 @@ except ImportError: Auth = Any from pydantic import BaseModel, Field + from crewai.tools import BaseTool @@ -34,16 +37,8 @@ class WeaviateVectorSearchTool(BaseTool): args_schema: Type[BaseModel] = WeaviateToolSchema query: Optional[str] = None - vectorizer: Optional[Vectorizers] = Field( - default=Configure.Vectorizer.text2vec_openai( - model="nomic-embed-text", - ) - ) - generative_model: Optional[str] = Field( - default=Configure.Generative.openai( - model="gpt-4o", - ), - ) + vectorizer: Optional[Vectorizers] = None + generative_model: Optional[str] = None collection_name: Optional[str] = None limit: Optional[int] = Field(default=3) headers: Optional[dict] = Field( @@ -58,6 +53,19 @@ class WeaviateVectorSearchTool(BaseTool): description="The API key for the Weaviate cluster", ) + def __init__(self, **kwargs): + super().__init__(**kwargs) + if WEAVIATE_AVAILABLE: + self.vectorizer = self.vectorizer or Configure.Vectorizer.text2vec_openai( + model="nomic-embed-text", + ) + self.generative_model = ( + self.generative_model + or Configure.Generative.openai( + model="gpt-4o", + ) + ) + def _run(self, query: str) -> str: if not WEAVIATE_AVAILABLE: raise ImportError(