diff --git a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py index 27f7db132..57211e64e 100644 --- a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py +++ b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py @@ -57,7 +57,6 @@ class SeleniumScrapingTool(BaseTool): wait_time: Optional[int] = 3 css_element: Optional[str] = None return_html: Optional[bool] = False - _options: Optional[dict] = None _by: Optional[Any] = None def __init__( @@ -91,8 +90,10 @@ class SeleniumScrapingTool(BaseTool): raise ImportError( "`selenium` and `webdriver-manager` package not found, please run `uv add selenium webdriver-manager`" ) - self.driver = webdriver.Chrome() - self._options = Options() + + options: Options = Options() + options.add_argument("--headless") + self.driver = webdriver.Chrome(options=options) self._by = By if cookie is not None: self.cookie = cookie @@ -116,28 +117,30 @@ class SeleniumScrapingTool(BaseTool): website_url = kwargs.get("website_url", self.website_url) css_element = kwargs.get("css_element", self.css_element) return_html = kwargs.get("return_html", self.return_html) - driver = self._create_driver(website_url, self.cookie, self.wait_time) + try: + self._make_request(website_url, self.cookie, self.wait_time) + content = self._get_content(css_element, return_html) + return "\n".join(content) + except Exception as e: + return f"Error scraping website: {str(e)}" + finally: + self.driver.close() - content = self._get_content(driver, css_element, return_html) - driver.close() - - return "\n".join(content) - - def _get_content(self, driver, css_element, return_html): + def _get_content(self, css_element, return_html): content = [] if self._is_css_element_empty(css_element): - content.append(self._get_body_content(driver, return_html)) + content.append(self._get_body_content(return_html)) else: - content.extend(self._get_elements_content(driver, css_element, return_html)) + content.extend(self._get_elements_content(css_element, return_html)) return content def _is_css_element_empty(self, css_element): return css_element is None or css_element.strip() == "" - def _get_body_content(self, driver, return_html): - body_element = driver.find_element(self._by.TAG_NAME, "body") + def _get_body_content(self, return_html): + body_element = self.driver.find_element(self._by.TAG_NAME, "body") return ( body_element.get_attribute("outerHTML") @@ -145,17 +148,17 @@ class SeleniumScrapingTool(BaseTool): else body_element.text ) - def _get_elements_content(self, driver, css_element, return_html): + def _get_elements_content(self, css_element, return_html): elements_content = [] - for element in driver.find_elements(self._by.CSS_SELECTOR, css_element): + for element in self.driver.find_elements(self._by.CSS_SELECTOR, css_element): elements_content.append( element.get_attribute("outerHTML") if return_html else element.text ) return elements_content - def _create_driver(self, url, cookie, wait_time): + def _make_request(self, url, cookie, wait_time): if not url: raise ValueError("URL cannot be empty") @@ -163,17 +166,13 @@ class SeleniumScrapingTool(BaseTool): if not re.match(r"^https?://", url): raise ValueError("URL must start with http:// or https://") - options = self._options - options.add_argument("--headless") - driver = self.driver(options=options) - driver.get(url) + self.driver.get(url) time.sleep(wait_time) if cookie: - driver.add_cookie(cookie) + self.driver.add_cookie(cookie) time.sleep(wait_time) - driver.get(url) + self.driver.get(url) time.sleep(wait_time) - return driver def close(self): self.driver.close() diff --git a/tests/tools/selenium_scraping_tool_test.py b/tests/tools/selenium_scraping_tool_test.py index 4e0b890b5..b360df3a1 100644 --- a/tests/tools/selenium_scraping_tool_test.py +++ b/tests/tools/selenium_scraping_tool_test.py @@ -1,7 +1,8 @@ -from unittest.mock import MagicMock, patch -import tempfile import os +import tempfile +from unittest.mock import MagicMock, patch +import pytest from bs4 import BeautifulSoup from crewai_tools.tools.selenium_scraping_tool.selenium_scraping_tool import ( @@ -24,7 +25,7 @@ def mock_driver_with_html(html_content): def initialize_tool_with(mock_driver): tool = SeleniumScrapingTool() - tool.driver = MagicMock(return_value=mock_driver) + tool.driver = mock_driver return tool @@ -33,7 +34,7 @@ def initialize_tool_with(mock_driver): def test_tool_initialization(mocked_chrome): temp_dir = tempfile.mkdtemp() mocked_chrome.return_value = MagicMock() - + tool = SeleniumScrapingTool() assert tool.website_url is None @@ -41,7 +42,7 @@ def test_tool_initialization(mocked_chrome): assert tool.cookie is None assert tool.wait_time == 3 assert tool.return_html is False - + try: os.rmdir(temp_dir) except: @@ -102,3 +103,13 @@ def test_scrape_with_return_html_false(_mocked_chrome_driver): mock_driver.get.assert_called_once_with("https://example.com") mock_driver.find_element.assert_called_with("tag name", "body") mock_driver.close.assert_called_once() + + +@patch("selenium.webdriver.Chrome") +def test_scrape_with_driver_error(_mocked_chrome_driver): + mock_driver = MagicMock() + mock_driver.find_element.side_effect = Exception("WebDriver error occurred") + tool = initialize_tool_with(mock_driver) + result = tool._run(website_url="https://example.com") + assert result == "Error scraping website: WebDriver error occurred" + mock_driver.close.assert_called_once()