mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 07:42:40 +00:00
refactor(selenium): improve driver management and add headless mode (#268)
- Refactor Selenium scraping tool to use single driver instance - Add headless mode configuration for Chrome - Improve error handling with try/finally - Simplify code structure and improve maintainability
This commit is contained in:
@@ -57,7 +57,6 @@ class SeleniumScrapingTool(BaseTool):
|
|||||||
wait_time: Optional[int] = 3
|
wait_time: Optional[int] = 3
|
||||||
css_element: Optional[str] = None
|
css_element: Optional[str] = None
|
||||||
return_html: Optional[bool] = False
|
return_html: Optional[bool] = False
|
||||||
_options: Optional[dict] = None
|
|
||||||
_by: Optional[Any] = None
|
_by: Optional[Any] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -91,8 +90,10 @@ class SeleniumScrapingTool(BaseTool):
|
|||||||
raise ImportError(
|
raise ImportError(
|
||||||
"`selenium` and `webdriver-manager` package not found, please run `uv add selenium webdriver-manager`"
|
"`selenium` and `webdriver-manager` package not found, please run `uv add selenium webdriver-manager`"
|
||||||
)
|
)
|
||||||
self.driver = webdriver.Chrome()
|
|
||||||
self._options = Options()
|
options: Options = Options()
|
||||||
|
options.add_argument("--headless")
|
||||||
|
self.driver = webdriver.Chrome(options=options)
|
||||||
self._by = By
|
self._by = By
|
||||||
if cookie is not None:
|
if cookie is not None:
|
||||||
self.cookie = cookie
|
self.cookie = cookie
|
||||||
@@ -116,28 +117,30 @@ class SeleniumScrapingTool(BaseTool):
|
|||||||
website_url = kwargs.get("website_url", self.website_url)
|
website_url = kwargs.get("website_url", self.website_url)
|
||||||
css_element = kwargs.get("css_element", self.css_element)
|
css_element = kwargs.get("css_element", self.css_element)
|
||||||
return_html = kwargs.get("return_html", self.return_html)
|
return_html = kwargs.get("return_html", self.return_html)
|
||||||
driver = self._create_driver(website_url, self.cookie, self.wait_time)
|
try:
|
||||||
|
self._make_request(website_url, self.cookie, self.wait_time)
|
||||||
|
content = self._get_content(css_element, return_html)
|
||||||
|
return "\n".join(content)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error scraping website: {str(e)}"
|
||||||
|
finally:
|
||||||
|
self.driver.close()
|
||||||
|
|
||||||
content = self._get_content(driver, css_element, return_html)
|
def _get_content(self, css_element, return_html):
|
||||||
driver.close()
|
|
||||||
|
|
||||||
return "\n".join(content)
|
|
||||||
|
|
||||||
def _get_content(self, driver, css_element, return_html):
|
|
||||||
content = []
|
content = []
|
||||||
|
|
||||||
if self._is_css_element_empty(css_element):
|
if self._is_css_element_empty(css_element):
|
||||||
content.append(self._get_body_content(driver, return_html))
|
content.append(self._get_body_content(return_html))
|
||||||
else:
|
else:
|
||||||
content.extend(self._get_elements_content(driver, css_element, return_html))
|
content.extend(self._get_elements_content(css_element, return_html))
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
def _is_css_element_empty(self, css_element):
|
def _is_css_element_empty(self, css_element):
|
||||||
return css_element is None or css_element.strip() == ""
|
return css_element is None or css_element.strip() == ""
|
||||||
|
|
||||||
def _get_body_content(self, driver, return_html):
|
def _get_body_content(self, return_html):
|
||||||
body_element = driver.find_element(self._by.TAG_NAME, "body")
|
body_element = self.driver.find_element(self._by.TAG_NAME, "body")
|
||||||
|
|
||||||
return (
|
return (
|
||||||
body_element.get_attribute("outerHTML")
|
body_element.get_attribute("outerHTML")
|
||||||
@@ -145,17 +148,17 @@ class SeleniumScrapingTool(BaseTool):
|
|||||||
else body_element.text
|
else body_element.text
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_elements_content(self, driver, css_element, return_html):
|
def _get_elements_content(self, css_element, return_html):
|
||||||
elements_content = []
|
elements_content = []
|
||||||
|
|
||||||
for element in driver.find_elements(self._by.CSS_SELECTOR, css_element):
|
for element in self.driver.find_elements(self._by.CSS_SELECTOR, css_element):
|
||||||
elements_content.append(
|
elements_content.append(
|
||||||
element.get_attribute("outerHTML") if return_html else element.text
|
element.get_attribute("outerHTML") if return_html else element.text
|
||||||
)
|
)
|
||||||
|
|
||||||
return elements_content
|
return elements_content
|
||||||
|
|
||||||
def _create_driver(self, url, cookie, wait_time):
|
def _make_request(self, url, cookie, wait_time):
|
||||||
if not url:
|
if not url:
|
||||||
raise ValueError("URL cannot be empty")
|
raise ValueError("URL cannot be empty")
|
||||||
|
|
||||||
@@ -163,17 +166,13 @@ class SeleniumScrapingTool(BaseTool):
|
|||||||
if not re.match(r"^https?://", url):
|
if not re.match(r"^https?://", url):
|
||||||
raise ValueError("URL must start with http:// or https://")
|
raise ValueError("URL must start with http:// or https://")
|
||||||
|
|
||||||
options = self._options
|
self.driver.get(url)
|
||||||
options.add_argument("--headless")
|
|
||||||
driver = self.driver(options=options)
|
|
||||||
driver.get(url)
|
|
||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
if cookie:
|
if cookie:
|
||||||
driver.add_cookie(cookie)
|
self.driver.add_cookie(cookie)
|
||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
driver.get(url)
|
self.driver.get(url)
|
||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
return driver
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.driver.close()
|
self.driver.close()
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from unittest.mock import MagicMock, patch
|
|
||||||
import tempfile
|
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
from crewai_tools.tools.selenium_scraping_tool.selenium_scraping_tool import (
|
from crewai_tools.tools.selenium_scraping_tool.selenium_scraping_tool import (
|
||||||
@@ -24,7 +25,7 @@ def mock_driver_with_html(html_content):
|
|||||||
|
|
||||||
def initialize_tool_with(mock_driver):
|
def initialize_tool_with(mock_driver):
|
||||||
tool = SeleniumScrapingTool()
|
tool = SeleniumScrapingTool()
|
||||||
tool.driver = MagicMock(return_value=mock_driver)
|
tool.driver = mock_driver
|
||||||
|
|
||||||
return tool
|
return tool
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ def initialize_tool_with(mock_driver):
|
|||||||
def test_tool_initialization(mocked_chrome):
|
def test_tool_initialization(mocked_chrome):
|
||||||
temp_dir = tempfile.mkdtemp()
|
temp_dir = tempfile.mkdtemp()
|
||||||
mocked_chrome.return_value = MagicMock()
|
mocked_chrome.return_value = MagicMock()
|
||||||
|
|
||||||
tool = SeleniumScrapingTool()
|
tool = SeleniumScrapingTool()
|
||||||
|
|
||||||
assert tool.website_url is None
|
assert tool.website_url is None
|
||||||
@@ -41,7 +42,7 @@ def test_tool_initialization(mocked_chrome):
|
|||||||
assert tool.cookie is None
|
assert tool.cookie is None
|
||||||
assert tool.wait_time == 3
|
assert tool.wait_time == 3
|
||||||
assert tool.return_html is False
|
assert tool.return_html is False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.rmdir(temp_dir)
|
os.rmdir(temp_dir)
|
||||||
except:
|
except:
|
||||||
@@ -102,3 +103,13 @@ def test_scrape_with_return_html_false(_mocked_chrome_driver):
|
|||||||
mock_driver.get.assert_called_once_with("https://example.com")
|
mock_driver.get.assert_called_once_with("https://example.com")
|
||||||
mock_driver.find_element.assert_called_with("tag name", "body")
|
mock_driver.find_element.assert_called_with("tag name", "body")
|
||||||
mock_driver.close.assert_called_once()
|
mock_driver.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("selenium.webdriver.Chrome")
|
||||||
|
def test_scrape_with_driver_error(_mocked_chrome_driver):
|
||||||
|
mock_driver = MagicMock()
|
||||||
|
mock_driver.find_element.side_effect = Exception("WebDriver error occurred")
|
||||||
|
tool = initialize_tool_with(mock_driver)
|
||||||
|
result = tool._run(website_url="https://example.com")
|
||||||
|
assert result == "Error scraping website: WebDriver error occurred"
|
||||||
|
mock_driver.close.assert_called_once()
|
||||||
|
|||||||
Reference in New Issue
Block a user