Add type hints

This commit is contained in:
Terry Tan Yongsheng
2024-12-17 13:45:50 +08:00
parent 2effe9a7d2
commit 81981e43b6

View File

@@ -1,10 +1,12 @@
import os import os
import re import re
from typing import Optional, Any from typing import Optional, Any, Union
from crewai.tools import BaseTool from crewai.tools import BaseTool
class SerpApiBaseTool(BaseTool): class SerpApiBaseTool(BaseTool):
"""Base class for SerpApi functionality with shared capabilities."""
client: Optional[Any] = None client: Optional[Any] = None
def __init__(self, **kwargs): def __init__(self, **kwargs):
@@ -14,7 +16,7 @@ class SerpApiBaseTool(BaseTool):
from serpapi import Client from serpapi import Client
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"`serpapi` package not found" "`serpapi` package not found, please install with `pip install serpapi`"
) )
api_key = os.getenv("SERPAPI_API_KEY") api_key = os.getenv("SERPAPI_API_KEY")
if not api_key: if not api_key:
@@ -23,7 +25,7 @@ class SerpApiBaseTool(BaseTool):
) )
self.client = Client(api_key=api_key) self.client = Client(api_key=api_key)
def _omit_fields(self, data, omit_patterns): def _omit_fields(self, data: Union[dict, list], omit_patterns: list[str]) -> None:
if isinstance(data, dict): if isinstance(data, dict):
for field in list(data.keys()): for field in list(data.keys()):
if any(re.compile(p).match(field) for p in omit_patterns): if any(re.compile(p).match(field) for p in omit_patterns):
@@ -34,4 +36,3 @@ class SerpApiBaseTool(BaseTool):
elif isinstance(data, list): elif isinstance(data, list):
for item in data: for item in data:
self._omit_fields(item, omit_patterns) self._omit_fields(item, omit_patterns)