mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Add type hints
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, Union
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
class SerpApiBaseTool(BaseTool):
|
||||
"""Base class for SerpApi functionality with shared capabilities."""
|
||||
|
||||
client: Optional[Any] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -14,7 +16,7 @@ class SerpApiBaseTool(BaseTool):
|
||||
from serpapi import Client
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`serpapi` package not found"
|
||||
"`serpapi` package not found, please install with `pip install serpapi`"
|
||||
)
|
||||
api_key = os.getenv("SERPAPI_API_KEY")
|
||||
if not api_key:
|
||||
@@ -23,7 +25,7 @@ class SerpApiBaseTool(BaseTool):
|
||||
)
|
||||
self.client = Client(api_key=api_key)
|
||||
|
||||
def _omit_fields(self, data, omit_patterns):
|
||||
def _omit_fields(self, data: Union[dict, list], omit_patterns: list[str]) -> None:
|
||||
if isinstance(data, dict):
|
||||
for field in list(data.keys()):
|
||||
if any(re.compile(p).match(field) for p in omit_patterns):
|
||||
@@ -34,4 +36,3 @@ class SerpApiBaseTool(BaseTool):
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
self._omit_fields(item, omit_patterns)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user