mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Merge pull request #172 from crewAIInc/fix/weaviate-init
fix: weaviate init parameters
This commit is contained in:
@@ -1,10 +1,12 @@
|
|||||||
from typing import Any, Type, Optional
|
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Any, Optional, Type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import weaviate
|
import weaviate
|
||||||
from weaviate.classes.config import Configure, Vectorizers
|
from weaviate.classes.config import Configure, Vectorizers
|
||||||
from weaviate.classes.init import Auth
|
from weaviate.classes.init import Auth
|
||||||
|
|
||||||
WEAVIATE_AVAILABLE = True
|
WEAVIATE_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
WEAVIATE_AVAILABLE = False
|
WEAVIATE_AVAILABLE = False
|
||||||
@@ -14,6 +16,7 @@ except ImportError:
|
|||||||
Auth = Any
|
Auth = Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
@@ -34,16 +37,8 @@ class WeaviateVectorSearchTool(BaseTool):
|
|||||||
args_schema: Type[BaseModel] = WeaviateToolSchema
|
args_schema: Type[BaseModel] = WeaviateToolSchema
|
||||||
query: Optional[str] = None
|
query: Optional[str] = None
|
||||||
|
|
||||||
vectorizer: Optional[Vectorizers] = Field(
|
vectorizer: Optional[Vectorizers] = None
|
||||||
default=Configure.Vectorizer.text2vec_openai(
|
generative_model: Optional[str] = None
|
||||||
model="nomic-embed-text",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
generative_model: Optional[str] = Field(
|
|
||||||
default=Configure.Generative.openai(
|
|
||||||
model="gpt-4o",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
collection_name: Optional[str] = None
|
collection_name: Optional[str] = None
|
||||||
limit: Optional[int] = Field(default=3)
|
limit: Optional[int] = Field(default=3)
|
||||||
headers: Optional[dict] = Field(
|
headers: Optional[dict] = Field(
|
||||||
@@ -58,6 +53,19 @@ class WeaviateVectorSearchTool(BaseTool):
|
|||||||
description="The API key for the Weaviate cluster",
|
description="The API key for the Weaviate cluster",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
if WEAVIATE_AVAILABLE:
|
||||||
|
self.vectorizer = self.vectorizer or Configure.Vectorizer.text2vec_openai(
|
||||||
|
model="nomic-embed-text",
|
||||||
|
)
|
||||||
|
self.generative_model = (
|
||||||
|
self.generative_model
|
||||||
|
or Configure.Generative.openai(
|
||||||
|
model="gpt-4o",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _run(self, query: str) -> str:
|
def _run(self, query: str) -> str:
|
||||||
if not WEAVIATE_AVAILABLE:
|
if not WEAVIATE_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
|
|||||||
Reference in New Issue
Block a user