mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
feat: attempt to make embedchain optional (#450)
* fix: attempt to make embedchain optional * fix: drop pydantic_settings dependency * fix: ensure the package is importable without any extra dependency. After making embedchain optional, many packages were uninstalled, which caused errors in some tools due to failing import directives.
This commit is contained in:
@@ -1,14 +1,23 @@
|
||||
from typing import Any
|
||||
|
||||
from embedchain import App
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
try:
|
||||
from embedchain import App
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
|
||||
class EmbedchainAdapter(Adapter):
|
||||
embedchain_app: App
|
||||
embedchain_app: Any # Will be App when embedchain is available
|
||||
summarize: bool = False
|
||||
|
||||
def __init__(self, **data):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**data)
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
result, sources = self.embedchain_app.query(
|
||||
question, citations=True, dry_run=(not self.summarize)
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain import App
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
try:
|
||||
from embedchain import App
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
|
||||
class PDFEmbedchainAdapter(Adapter):
|
||||
embedchain_app: App
|
||||
embedchain_app: Any # Will be App when embedchain is available
|
||||
summarize: bool = False
|
||||
src: Optional[str] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**data)
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
where = (
|
||||
{"app_id": self.embedchain_app.config.id, "source": self.src}
|
||||
|
||||
@@ -5,15 +5,19 @@ from typing import Any, Dict, Optional, Type
|
||||
import aiohttp
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class BrightDataConfig(BaseSettings):
|
||||
class BrightDataConfig(BaseModel):
|
||||
API_URL: str = "https://api.brightdata.com"
|
||||
DEFAULT_TIMEOUT: int = 600
|
||||
DEFAULT_POLLING_INTERVAL: int = 1
|
||||
|
||||
class Config:
|
||||
env_prefix = "BRIGHTDATA_"
|
||||
@classmethod
|
||||
def from_env(cls):
|
||||
return cls(
|
||||
API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com"),
|
||||
DEFAULT_TIMEOUT=int(os.environ.get("BRIGHTDATA_DEFAULT_TIMEOUT", "600")),
|
||||
DEFAULT_POLLING_INTERVAL=int(os.environ.get("BRIGHTDATA_DEFAULT_POLLING_INTERVAL", "1"))
|
||||
)
|
||||
class BrightDataDatasetToolException(Exception):
|
||||
"""Exception raised for custom error in the application."""
|
||||
|
||||
@@ -48,7 +52,7 @@ class BrightDataDatasetToolSchema(BaseModel):
|
||||
default=None, description="Additional params if any"
|
||||
)
|
||||
|
||||
config = BrightDataConfig()
|
||||
config = BrightDataConfig.from_env()
|
||||
|
||||
BRIGHTDATA_API_URL = config.API_URL
|
||||
timeout = config.DEFAULT_TIMEOUT
|
||||
|
||||
@@ -5,12 +5,15 @@ from typing import Any, Optional, Type
|
||||
import requests
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class BrightDataConfig(BaseSettings):
|
||||
class BrightDataConfig(BaseModel):
|
||||
API_URL: str = "https://api.brightdata.com/request"
|
||||
class Config:
|
||||
env_prefix = "BRIGHTDATA_"
|
||||
|
||||
@classmethod
|
||||
def from_env(cls):
|
||||
return cls(
|
||||
API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com/request")
|
||||
)
|
||||
|
||||
class BrightDataSearchToolSchema(BaseModel):
|
||||
"""
|
||||
@@ -73,7 +76,7 @@ class BrightDataSearchTool(BaseTool):
|
||||
name: str = "Bright Data SERP Search"
|
||||
description: str = "Tool to perform web search using Bright Data SERP API."
|
||||
args_schema: Type[BaseModel] = BrightDataSearchToolSchema
|
||||
_config = BrightDataConfig()
|
||||
_config = BrightDataConfig.from_env()
|
||||
base_url: str = ""
|
||||
api_key: str = ""
|
||||
zone: str = ""
|
||||
|
||||
@@ -4,12 +4,15 @@ from typing import Any, Optional, Type
|
||||
import requests
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class BrightDataConfig(BaseSettings):
|
||||
class BrightDataConfig(BaseModel):
|
||||
API_URL: str = "https://api.brightdata.com/request"
|
||||
class Config:
|
||||
env_prefix = "BRIGHTDATA_"
|
||||
|
||||
@classmethod
|
||||
def from_env(cls):
|
||||
return cls(
|
||||
API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com/request")
|
||||
)
|
||||
|
||||
class BrightDataUnlockerToolSchema(BaseModel):
|
||||
"""
|
||||
@@ -57,7 +60,7 @@ class BrightDataWebUnlockerTool(BaseTool):
|
||||
name: str = "Bright Data Web Unlocker Scraping"
|
||||
description: str = "Tool to perform web scraping using Bright Data Web Unlocker"
|
||||
args_schema: Type[BaseModel] = BrightDataUnlockerToolSchema
|
||||
_config = BrightDataConfig()
|
||||
_config = BrightDataConfig.from_env()
|
||||
base_url: str = ""
|
||||
api_key: str = ""
|
||||
zone: str = ""
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -37,6 +42,8 @@ class CodeDocsSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, docs_url: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(docs_url, data_type=DataType.DOCS_SITE)
|
||||
|
||||
def _run(
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -37,6 +42,8 @@ class CSVSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, csv: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(csv, data_type=DataType.CSV)
|
||||
|
||||
def _run(
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
from embedchain.loaders.directory_loader import DirectoryLoader
|
||||
try:
|
||||
from embedchain.loaders.directory_loader import DirectoryLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -29,6 +34,8 @@ class DirectorySearchTool(RagTool):
|
||||
args_schema: Type[BaseModel] = DirectorySearchToolSchema
|
||||
|
||||
def __init__(self, directory: Optional[str] = None, **kwargs):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
if directory is not None:
|
||||
self.add(directory)
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -43,6 +48,8 @@ class DOCXSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, docx: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(docx, data_type=DataType.DOCX)
|
||||
|
||||
def _run(
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import List, Optional, Type
|
||||
from typing import List, Optional, Type, Any
|
||||
|
||||
try:
|
||||
from embedchain.loaders.github import GithubLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from embedchain.loaders.github import GithubLoader
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -37,7 +42,7 @@ class GithubSearchTool(RagTool):
|
||||
default_factory=lambda: ["code", "repo", "pr", "issue"],
|
||||
description="Content types you want to be included search, options: [code, repo, pr, issue]",
|
||||
)
|
||||
_loader: GithubLoader | None = PrivateAttr(default=None)
|
||||
_loader: Any | None = PrivateAttr(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -45,6 +50,8 @@ class GithubSearchTool(RagTool):
|
||||
content_types: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
self._loader = GithubLoader(config={"token": self.gh_token})
|
||||
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@@ -37,6 +42,8 @@ class MDXSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, mdx: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(mdx, data_type=DataType.MDX)
|
||||
|
||||
def _run(
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import Any, Type
|
||||
|
||||
from embedchain.loaders.mysql import MySQLLoader
|
||||
try:
|
||||
from embedchain.loaders.mysql import MySQLLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -22,6 +27,8 @@ class MySQLSearchTool(RagTool):
|
||||
db_uri: str = Field(..., description="Mandatory database URI")
|
||||
|
||||
def __init__(self, table_name: str, **kwargs):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
kwargs["data_type"] = "mysql"
|
||||
kwargs["loader"] = MySQLLoader(config=dict(url=self.db_uri))
|
||||
|
||||
@@ -2,8 +2,13 @@ from typing import Any, Type, Union
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
try:
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
SQLALCHEMY_AVAILABLE = True
|
||||
except ImportError:
|
||||
SQLALCHEMY_AVAILABLE = False
|
||||
|
||||
|
||||
class NL2SQLToolInput(BaseModel):
|
||||
@@ -25,6 +30,9 @@ class NL2SQLTool(BaseTool):
|
||||
args_schema: Type[BaseModel] = NL2SQLToolInput
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if not SQLALCHEMY_AVAILABLE:
|
||||
raise ImportError("sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`")
|
||||
|
||||
data = {}
|
||||
tables = self._fetch_available_tables()
|
||||
|
||||
@@ -58,6 +66,9 @@ class NL2SQLTool(BaseTool):
|
||||
return data
|
||||
|
||||
def execute_sql(self, sql_query: str) -> Union[list, str]:
|
||||
if not SQLALCHEMY_AVAILABLE:
|
||||
raise ImportError("sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`")
|
||||
|
||||
engine = create_engine(self.db_uri)
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@@ -36,6 +41,8 @@ class PDFSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, pdf: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(pdf, data_type=DataType.PDF_FILE)
|
||||
|
||||
def _run(
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import Any, Type
|
||||
|
||||
from embedchain.loaders.postgres import PostgresLoader
|
||||
try:
|
||||
from embedchain.loaders.postgres import PostgresLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -22,6 +27,8 @@ class PGSearchTool(RagTool):
|
||||
db_uri: str = Field(..., description="Mandatory database URI")
|
||||
|
||||
def __init__(self, table_name: str, **kwargs):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
kwargs["data_type"] = "postgres"
|
||||
kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
|
||||
|
||||
@@ -40,7 +40,11 @@ class RagTool(BaseTool):
|
||||
@model_validator(mode="after")
|
||||
def _set_default_adapter(self):
|
||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||
try:
|
||||
from embedchain import App
|
||||
except ImportError:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
|
||||
with portalocker.Lock("crewai-rag-tool.lock", timeout=10):
|
||||
|
||||
@@ -2,10 +2,15 @@ import os
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
BEAUTIFULSOUP_AVAILABLE = True
|
||||
except ImportError:
|
||||
BEAUTIFULSOUP_AVAILABLE = False
|
||||
|
||||
|
||||
class FixedScrapeElementFromWebsiteToolSchema(BaseModel):
|
||||
"""Input for ScrapeElementFromWebsiteTool."""
|
||||
@@ -61,6 +66,9 @@ class ScrapeElementFromWebsiteTool(BaseTool):
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if not BEAUTIFULSOUP_AVAILABLE:
|
||||
raise ImportError("beautifulsoup4 is not installed. Please install it with `pip install crewai-tools[beautifulsoup4]`")
|
||||
|
||||
website_url = kwargs.get("website_url", self.website_url)
|
||||
css_element = kwargs.get("css_element", self.css_element)
|
||||
page = requests.get(
|
||||
|
||||
@@ -3,7 +3,11 @@ import re
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
BEAUTIFULSOUP_AVAILABLE = True
|
||||
except ImportError:
|
||||
BEAUTIFULSOUP_AVAILABLE = False
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -40,6 +44,9 @@ class ScrapeWebsiteTool(BaseTool):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if not BEAUTIFULSOUP_AVAILABLE:
|
||||
raise ImportError("beautifulsoup4 is not installed. Please install it with `pip install crewai-tools[beautifulsoup4]`")
|
||||
|
||||
if website_url is not None:
|
||||
self.website_url = website_url
|
||||
self.description = (
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -39,6 +44,8 @@ class WebsiteSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, website: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(website, data_type=DataType.WEB_PAGE)
|
||||
|
||||
def _run(
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -39,6 +44,8 @@ class YoutubeVideoSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, youtube_video_url: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
|
||||
|
||||
def _run(
|
||||
|
||||
Reference in New Issue
Block a user