Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-30 02:28:13 +00:00.

Compare commits: 7 commits, devin/1768 ... devin/1739

- 95534de830
- 70d017f05d
- 5b317f3eb3
- c453a65a0a
- 5e4f7df7dd
- 6ce41e4f11
- ecd16486c1

docs/tools/faiss_search_tool.mdx (new file, 142 lines)

# FAISS Search Tool

The FAISS Search Tool enables efficient vector similarity search using Facebook AI Similarity Search (FAISS).

## Usage

```python
from crewai import Agent
from crewai.tools import FAISSSearchTool

# Initialize tool
search_tool = FAISSSearchTool(
    index_type="L2",       # or "IP" for inner product
    dimension=384,         # Match your embedder's dimension
    embedder_config={
        "provider": "fastembed",
        "model": "BAAI/bge-small-en-v1.5"
    }
)

# Add documents (with error handling)
try:
    search_tool.add_texts([
        "Document 1 content",
        "Document 2 content",
        # ...
    ])
except ValueError as e:
    print(f"Failed to add documents: {e}")

# Add large document sets efficiently
try:
    search_tool.add_texts_batch(
        texts=["Doc 1", "Doc 2", ...],  # Large list of documents
        batch_size=1000                 # Process in batches to manage memory
    )
except ValueError as e:
    print(f"Failed to add documents in batch: {e}")

# Search with error handling
try:
    results = search_tool.run(
        query="search query",
        k=3,                  # Number of results
        score_threshold=0.6   # Minimum similarity score
    )
    for result in results:
        print(f"Text: {result['text']}, Score: {result['score']}")
except ValueError as e:
    print(f"Search failed: {e}")

# Create agent with tool
agent = Agent(
    role="researcher",
    goal="Find relevant information",
    tools=[search_tool]
)
```
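
With the agent wired up, the tool can be exercised inside a crew like any other CrewAI tool. The following is a minimal sketch, assuming the standard `Task`/`Crew` API; the task description and expected output are illustrative only:

```python
from crewai import Crew, Task

research_task = Task(
    description="Search the indexed documents for information about FAISS and summarize it.",
    expected_output="A short summary of the most relevant documents.",
    agent=agent,  # the researcher agent defined above, which carries search_tool
)

crew = Crew(agents=[agent], tasks=[research_task])
result = crew.kickoff()
print(result)
```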

## Configuration

| Parameter | Type | Description |
|-----------|------|-------------|
| index_type | str | FAISS index type ("L2" or "IP") |
| dimension | int | Embedding dimension |
| embedder_config | dict | Embedder configuration |

## Parameters

### index_type

- `"L2"`: Euclidean distance (default)
- `"IP"`: Inner product similarity (see the sketch below)
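
To make the difference between the two index types concrete, here is a small standalone sketch that uses FAISS directly rather than this tool's API; note that with L2-normalized vectors an inner-product (`IP`) index ranks by cosine similarity:

```python
import faiss
import numpy as np

dim = 4
vectors = np.random.rand(10, dim).astype(np.float32)
faiss.normalize_L2(vectors)        # in-place normalization, so IP == cosine similarity

l2_index = faiss.IndexFlatL2(dim)  # smaller distance = more similar
ip_index = faiss.IndexFlatIP(dim)  # larger inner product = more similar
l2_index.add(vectors)
ip_index.add(vectors)

query = vectors[:1]                # reuse the first stored vector as the query
l2_distances, l2_ids = l2_index.search(query, 3)
ip_scores, ip_ids = ip_index.search(query, 3)
print("L2 distances:", l2_distances[0], "ids:", l2_ids[0])
print("IP similarities:", ip_scores[0], "ids:", ip_ids[0])
```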

### dimension

Default is 384, which matches the BAAI/bge-small-en-v1.5 model. Adjust this to match your chosen embedder model's output dimension.
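
If you are unsure of a model's output dimension, one option is to embed a sample string and check the vector length before building the index. This sketch assumes the `FastEmbed` embedder used in the tool's implementation (shown further down in this diff), whose `embed_text` method returns a single vector:

```python
from crewai.knowledge.embedder.fastembed import FastEmbed

embedder = FastEmbed()
probe = embedder.embed_text("dimension probe")
print(len(probe))  # 384 for the default BAAI/bge-small-en-v1.5 model
```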

### embedder_config

Configuration for the embedding model. Supports all CrewAI embedder providers:

- fastembed (default)
- openai
- google
- ollama
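
A non-default provider is configured the same way. The snippet below is a hypothetical sketch that follows the flat `provider`/`model` shape used on this page; the exact keys each provider accepts may differ, and the 768 dimension assumes an embedding model such as nomic-embed-text:

```python
# Hypothetical Ollama-backed configuration; keys mirror the shape used above.
ollama_tool = FAISSSearchTool(
    index_type="L2",
    dimension=768,  # must match the chosen embedding model's output size
    embedder_config={
        "provider": "ollama",
        "model": "nomic-embed-text",
    },
)
```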

## Error Handling

The tool includes comprehensive error handling:

```python
# Invalid index type
try:
    tool = FAISSSearchTool(index_type="INVALID")
except ValueError as e:
    print(f"Invalid index type: {e}")

# Empty query
try:
    results = tool.run(query="")
except ValueError as e:
    print(f"Invalid query: {e}")  # "Query cannot be empty"

# Invalid k value
try:
    results = tool.run(query="test", k=0)
except ValueError as e:
    print(f"Invalid k: {e}")  # "k must be positive"

# Invalid score threshold
try:
    results = tool.run(query="test", score_threshold=1.5)
except ValueError as e:
    print(f"Invalid threshold: {e}")  # "score_threshold must be between 0 and 1"
```

## Performance Considerations

### Memory Management

For large document sets, use batch processing to manage memory efficiently:

```python
# Process documents in batches
tool.add_texts_batch(texts=large_document_list, batch_size=1000)
```

### Index Management

Monitor and manage index size:

```python
# Check index size
print(f"Current index size: {tool.index_size}")

# Check if index is empty
if tool.is_empty:
    print("Index is empty")

# Clear index if needed
tool.clear_index()
```

### Performance Metrics

The tool is optimized for performance:

- Search operations typically complete within 1 second for indices up to 1000 documents (see the timing sketch below)
- Batch processing helps manage memory for large document sets
- Input sanitization ensures query safety without significant overhead
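
To check the latency claim on your own data, a simple wall-clock measurement around `run()` is enough. A minimal sketch, assuming a `search_tool` that already has documents indexed:

```python
import time

start = time.perf_counter()
results = search_tool.run(query="search query", k=3)
elapsed = time.perf_counter() - start
print(f"Returned {len(results)} results in {elapsed:.3f}s")
```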

```diff
@@ -37,6 +37,7 @@ dependencies = [
     "tomli>=2.0.2",
     "blinker>=1.9.0",
     "json5>=0.10.0",
+    "faiss-cpu>=1.7.4,<2.0.0",
 ]

 [project.urls]
```

```diff
@@ -1 +1,4 @@
 from .base_tool import BaseTool, tool
+from .faiss_search_tool import FAISSSearchTool
+
+__all__ = ["BaseTool", "tool", "FAISSSearchTool"]
```

src/crewai/tools/faiss_search_tool.py (new file, 175 lines)

```python
import logging
import re
from typing import Any, Dict, List, Optional

import faiss
import numpy as np
from pydantic import BaseModel, Field

from crewai.tools import BaseTool
from crewai.utilities import EmbeddingConfigurator

logger = logging.getLogger(__name__)


class FAISSSearchTool(BaseTool):
    """FAISS vector similarity search tool for efficient document search."""

    model_config = {"extra": "allow"}

    name: str = "FAISS Search Tool"
    description: str = "Search through documents using FAISS vector similarity search"
    embedder_config: Optional[Dict[str, Any]] = Field(default=None)
    dimension: int = Field(default=384)  # Default for BAAI/bge-small-en-v1.5
    texts: List[str] = Field(default_factory=list)
    index_type: str = Field(default="L2")
    index: Any = Field(default=None)  # FAISS index instance
    embedder: Any = Field(default=None)  # Embedder instance

    def __init__(
        self,
        index_type: str = "L2",
        dimension: int = 384,
        embedder_config: Optional[Dict[str, Any]] = None,
    ):
        """Initialize FAISS search tool.

        Args:
            index_type: Type of FAISS index ("L2" or "IP")
            dimension: Embedding dimension
            embedder_config: Configuration for the embedder
        """
        super().__init__()
        self.dimension = dimension
        self.embedder_config = embedder_config
        self.index_type = index_type
        self.index = self._create_index(index_type)
        self._initialize_embedder()

    def _create_index(self, index_type: str) -> faiss.Index:
        """Create FAISS index of specified type.

        Args:
            index_type: Type of index ("L2" or "IP")

        Returns:
            FAISS index instance

        Raises:
            ValueError: If index_type is not supported
        """
        if index_type == "L2":
            return faiss.IndexFlatL2(self.dimension)
        elif index_type == "IP":
            return faiss.IndexFlatIP(self.dimension)
        else:
            raise ValueError(f"Unsupported index type: {index_type}")

    def _initialize_embedder(self):
        """Initialize the embedder using the provided configuration."""
        from crewai.knowledge.embedder.fastembed import FastEmbed

        self.embedder = FastEmbed()

    def _sanitize_query(self, query: str) -> str:
        """Remove potentially harmful characters from query.

        Args:
            query: Input query string

        Returns:
            Sanitized query string
        """
        return re.sub(r'[^\w\s]', '', query)

    def _run(
        self,
        query: str,
        k: int = 3,
        score_threshold: float = 0.6
    ) -> List[Dict[str, Any]]:
        """Search for similar texts using FAISS.

        Args:
            query: Search query
            k: Number of results to return
            score_threshold: Minimum similarity score threshold

        Returns:
            List of dictionaries containing matched texts and scores

        Raises:
            ValueError: If input parameters are invalid
        """
        if not query.strip():
            raise ValueError("Query cannot be empty")
        if k < 1:
            raise ValueError("k must be positive")
        if not 0 <= score_threshold <= 1:
            raise ValueError("score_threshold must be between 0 and 1")

        logger.debug(f"Searching for query: {query} with k={k}")
        query = self._sanitize_query(query)
        query_embedding = self.embedder.embed_text(query)

        D, I = self.index.search(
            np.array([query_embedding], dtype=np.float32),
            k
        )

        results = []
        for i, (dist, idx) in enumerate(zip(D[0], I[0])):
            if idx < len(self.texts):
                score = 1.0 / (1.0 + dist)  # Convert distance to similarity score
                if score >= score_threshold:
                    results.append({
                        "text": self.texts[idx],
                        "score": score
                    })
        return results

    def add_texts(self, texts: List[str]) -> None:
        """Add texts to the search index.

        Args:
            texts: List of texts to add

        Raises:
            ValueError: If embedding or indexing fails
        """
        try:
            embeddings = self.embedder.embed_texts(texts)
            self.index.add(np.array(embeddings, dtype=np.float32))
            self.texts.extend(texts)
        except Exception as e:
            raise ValueError(f"Failed to add texts: {str(e)}")

    def add_texts_batch(self, texts: List[str], batch_size: int = 1000) -> None:
        """Add texts in batches to prevent memory issues.

        Args:
            texts: List of texts to add
            batch_size: Size of each batch

        Raises:
            ValueError: If batch_size is invalid
        """
        if batch_size < 1:
            raise ValueError("batch_size must be positive")

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            self.add_texts(batch)

    def clear_index(self) -> None:
        """Clear the index and stored texts."""
        self.index = self._create_index(self.index_type)
        self.texts = []

    @property
    def index_size(self) -> int:
        """Return number of vectors in index."""
        return len(self.texts)

    @property
    def is_empty(self) -> bool:
        """Check if index is empty."""
        return len(self.texts) == 0
```
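
One detail worth noting in `_run` above: L2 distances are converted to similarity scores via `1 / (1 + distance)`, so with the default `score_threshold=0.6` only hits with a distance below roughly 0.67 survive the filter. A quick check of that arithmetic:

```python
# Sanity-check the distance-to-score conversion used in _run.
for dist in (0.0, 0.25, 0.67, 1.0, 4.0):
    score = 1.0 / (1.0 + dist)
    print(f"distance={dist:<4} -> score={score:.2f}")
# distance=0.0  -> score=1.00
# distance=0.25 -> score=0.80
# distance=0.67 -> score=0.60
# distance=1.0  -> score=0.50
# distance=4.0  -> score=0.20
```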

tests/tools/test_faiss_search_tool.py (new file, 46 lines)

```python
import time

import numpy as np
import pytest

from crewai.tools import FAISSSearchTool


def test_faiss_search_tool_initialization():
    tool = FAISSSearchTool()
    assert tool.name == "FAISS Search Tool"
    assert tool.dimension == 384


def test_faiss_search_with_texts():
    tool = FAISSSearchTool()
    texts = [
        "The quick brown fox",
        "jumps over the lazy dog",
        "A completely different text"
    ]
    tool.add_texts(texts)

    results = tool.run(
        query="quick fox",
        k=2,
        score_threshold=0.5
    )

    assert len(results) > 0
    assert isinstance(results[0]["text"], str)
    assert isinstance(results[0]["score"], float)


def test_faiss_search_threshold_filtering():
    tool = FAISSSearchTool()
    texts = ["Text A", "Text B", "Text C"]
    tool.add_texts(texts)

    results = tool.run(
        query="Something completely different",
        score_threshold=0.99  # High threshold
    )

    assert len(results) == 0  # No results above threshold


def test_invalid_index_type():
    with pytest.raises(ValueError):
        FAISSSearchTool(index_type="INVALID")
```
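
The validation paths in `_run` (empty query, non-positive `k`, out-of-range threshold) are not exercised by these tests. A hypothetical addition, mirroring the error-handling examples in the documentation above, could look like this:

```python
def test_empty_query_raises():
    tool = FAISSSearchTool()
    tool.add_texts(["Some text"])
    with pytest.raises(ValueError):
        tool.run(query="")


def test_invalid_k_raises():
    tool = FAISSSearchTool()
    tool.add_texts(["Some text"])
    with pytest.raises(ValueError):
        tool.run(query="test", k=0)
```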