diff --git a/src/crewai_tools/tools/selenium_scraping_tool/README.md b/src/crewai_tools/tools/selenium_scraping_tool/README.md index 631fcfe0e..e2ddefba1 100644 --- a/src/crewai_tools/tools/selenium_scraping_tool/README.md +++ b/src/crewai_tools/tools/selenium_scraping_tool/README.md @@ -31,3 +31,4 @@ tool = SeleniumScrapingTool(website_url='https://example.com', css_element='.mai - `css_element`: Mandatory. The CSS selector for a specific element to scrape from the website. - `cookie`: Optional. A dictionary containing cookie information. This parameter allows the tool to simulate a session with cookie information, providing access to content that may be restricted to logged-in users. - `wait_time`: Optional. The number of seconds the tool waits after loading the website and after setting a cookie, before scraping the content. This allows for dynamic content to load properly. +- `return_html`: Optional. If True, the tool returns HTML content. If False, the tool returns text content. diff --git a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py index 47910f35b..5f7d9391b 100644 --- a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py +++ b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py @@ -11,8 +11,6 @@ from selenium.webdriver.common.by import By class FixedSeleniumScrapingToolSchema(BaseModel): """Input for SeleniumScrapingTool.""" - pass - class SeleniumScrapingToolSchema(FixedSeleniumScrapingToolSchema): """Input for SeleniumScrapingTool.""" @@ -33,6 +31,7 @@ class SeleniumScrapingTool(BaseTool): cookie: Optional[dict] = None wait_time: Optional[int] = 3 css_element: Optional[str] = None + return_html: Optional[bool] = False def __init__( self, @@ -63,18 +62,46 @@ class SeleniumScrapingTool(BaseTool): ) -> Any: website_url = kwargs.get("website_url", self.website_url) css_element = kwargs.get("css_element", self.css_element) + return_html = kwargs.get("return_html", self.return_html) driver = self._create_driver(website_url, self.cookie, self.wait_time) - content = [] - if css_element is None or css_element.strip() == "": - body_text = driver.find_element(By.TAG_NAME, "body").text - content.append(body_text) - else: - for element in driver.find_elements(By.CSS_SELECTOR, css_element): - content.append(element.text) + content = self._get_content(driver, css_element, return_html) driver.close() + return "\n".join(content) + def _get_content(self, driver, css_element, return_html): + content = [] + + if self._is_css_element_empty(css_element): + content.append(self._get_body_content(driver, return_html)) + else: + content.extend(self._get_elements_content(driver, css_element, return_html)) + + return content + + def _is_css_element_empty(self, css_element): + return css_element is None or css_element.strip() == "" + + def _get_body_content(self, driver, return_html): + body_element = driver.find_element(By.TAG_NAME, "body") + + return ( + body_element.get_attribute("outerHTML") + if return_html + else body_element.text + ) + + def _get_elements_content(self, driver, css_element, return_html): + elements_content = [] + + for element in driver.find_elements(By.CSS_SELECTOR, css_element): + elements_content.append( + element.get_attribute("outerHTML") if return_html else element.text + ) + + return elements_content + def _create_driver(self, url, cookie, wait_time): options = Options() options.add_argument("--headless") diff --git a/tests/tools/selenium_scraping_tool_test.py b/tests/tools/selenium_scraping_tool_test.py new file mode 100644 index 000000000..271047449 --- /dev/null +++ b/tests/tools/selenium_scraping_tool_test.py @@ -0,0 +1,93 @@ +from unittest.mock import MagicMock, patch + +from bs4 import BeautifulSoup + +from crewai_tools.tools.selenium_scraping_tool.selenium_scraping_tool import ( + SeleniumScrapingTool, +) + + +def mock_driver_with_html(html_content): + driver = MagicMock() + mock_element = MagicMock() + mock_element.get_attribute.return_value = html_content + bs = BeautifulSoup(html_content, "html.parser") + mock_element.text = bs.get_text() + + driver.find_elements.return_value = [mock_element] + driver.find_element.return_value = mock_element + + return driver + + +def initialize_tool_with(mock_driver): + tool = SeleniumScrapingTool() + tool.driver = MagicMock(return_value=mock_driver) + + return tool + + +def test_tool_initialization(): + tool = SeleniumScrapingTool() + + assert tool.website_url is None + assert tool.css_element is None + assert tool.cookie is None + assert tool.wait_time == 3 + assert tool.return_html is False + + +@patch("selenium.webdriver.Chrome") +def test_scrape_without_css_selector(_mocked_chrome_driver): + html_content = "