Merge pull request #172 from crewAIInc/fix/weaviate-init

fix: weaviate init parameters
This commit is contained in:
Eduardo Chiarotti
2025-01-03 08:56:45 -03:00
committed by GitHub

View File

@@ -1,10 +1,12 @@
from typing import Any, Type, Optional
import os
import json import json
import os
from typing import Any, Optional, Type
try: try:
import weaviate import weaviate
from weaviate.classes.config import Configure, Vectorizers from weaviate.classes.config import Configure, Vectorizers
from weaviate.classes.init import Auth from weaviate.classes.init import Auth
WEAVIATE_AVAILABLE = True WEAVIATE_AVAILABLE = True
except ImportError: except ImportError:
WEAVIATE_AVAILABLE = False WEAVIATE_AVAILABLE = False
@@ -14,6 +16,7 @@ except ImportError:
Auth = Any Auth = Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.tools import BaseTool from crewai.tools import BaseTool
@@ -34,16 +37,8 @@ class WeaviateVectorSearchTool(BaseTool):
args_schema: Type[BaseModel] = WeaviateToolSchema args_schema: Type[BaseModel] = WeaviateToolSchema
query: Optional[str] = None query: Optional[str] = None
vectorizer: Optional[Vectorizers] = Field( vectorizer: Optional[Vectorizers] = None
default=Configure.Vectorizer.text2vec_openai( generative_model: Optional[str] = None
model="nomic-embed-text",
)
)
generative_model: Optional[str] = Field(
default=Configure.Generative.openai(
model="gpt-4o",
),
)
collection_name: Optional[str] = None collection_name: Optional[str] = None
limit: Optional[int] = Field(default=3) limit: Optional[int] = Field(default=3)
headers: Optional[dict] = Field( headers: Optional[dict] = Field(
@@ -58,6 +53,19 @@ class WeaviateVectorSearchTool(BaseTool):
description="The API key for the Weaviate cluster", description="The API key for the Weaviate cluster",
) )
def __init__(self, **kwargs):
super().__init__(**kwargs)
if WEAVIATE_AVAILABLE:
self.vectorizer = self.vectorizer or Configure.Vectorizer.text2vec_openai(
model="nomic-embed-text",
)
self.generative_model = (
self.generative_model
or Configure.Generative.openai(
model="gpt-4o",
)
)
def _run(self, query: str) -> str: def _run(self, query: str) -> str:
if not WEAVIATE_AVAILABLE: if not WEAVIATE_AVAILABLE:
raise ImportError( raise ImportError(