mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
feat: Add FAISS search tool
- Implement FAISSSearchTool for vector similarity search - Add comprehensive unit tests - Update documentation with usage examples - Add FAISS dependency Closes #2118 Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
58
docs/tools/faiss_search_tool.mdx
Normal file
58
docs/tools/faiss_search_tool.mdx
Normal file
@@ -0,0 +1,58 @@
|
||||
# FAISS Search Tool
|
||||
|
||||
The FAISS Search Tool lets an agent run efficient vector similarity search over texts you add to it, using FAISS (Facebook AI Similarity Search).
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai.tools import FAISSSearchTool
|
||||
|
||||
# Initialize tool
|
||||
search_tool = FAISSSearchTool(
|
||||
index_type="L2", # or "IP" for inner product
|
||||
dimension=384, # Match your embedder's dimension
|
||||
embedder_config={
|
||||
"provider": "fastembed",
|
||||
"model": "BAAI/bge-small-en-v1.5"
|
||||
}
|
||||
)
|
||||
|
||||
# Add documents
|
||||
search_tool.add_texts([
|
||||
"Document 1 content",
|
||||
"Document 2 content",
|
||||
# ...
|
||||
])
|
||||
|
||||
# Create agent with tool
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="Find relevant information",
|
||||
tools=[search_tool]
|
||||
)
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
| Parameter | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| index_type | str | FAISS index type ("L2" or "IP") |
|
||||
| dimension | int | Embedding dimension |
|
||||
| embedder_config | dict | Embedder configuration |
|
||||
|
||||
## Parameters
|
||||
|
||||
### index_type
|
||||
- `"L2"`: Euclidean distance (default)
|
||||
- `"IP"`: Inner product similarity
|
||||
|
||||
### dimension
|
||||
Defaults to 384, the embedding size of the `BAAI/bge-small-en-v1.5` model. Set this to the output dimension of the embedder model you use; the index is created with this fixed dimension, so the two must match.
|
||||
|
||||
### embedder_config
|
||||
Configuration for the embedding model. It accepts the CrewAI embedder providers, including:
|
||||
- fastembed (default)
|
||||
- openai
|
||||
- google
|
||||
- ollama
|
||||
@@ -37,6 +37,7 @@ dependencies = [
|
||||
"tomli>=2.0.2",
|
||||
"blinker>=1.9.0",
|
||||
"json5>=0.10.0",
|
||||
"faiss-cpu>=1.7.4",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -1 +1,4 @@
|
||||
from .base_tool import BaseTool, tool
|
||||
from .faiss_search_tool import FAISSSearchTool
|
||||
|
||||
__all__ = ["BaseTool", "tool", "FAISSSearchTool"]
|
||||
|
||||
64
src/crewai/tools/faiss_search_tool.py
Normal file
64
src/crewai/tools/faiss_search_tool.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import List, Dict, Any, Optional
|
||||
import faiss
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
|
||||
class FAISSSearchTool(BaseTool):
    """Search previously added texts by vector similarity using a FAISS index.

    Texts passed to ``add_texts`` are embedded and stored in a flat FAISS
    index (``IndexFlatL2`` or ``IndexFlatIP``). ``_run`` embeds a query the
    same way and returns the stored texts whose score passes a threshold.
    """

    name: str = "FAISS Search Tool"
    description: str = "Search through documents using FAISS vector similarity search"

    # NOTE(review): BaseTool is presumably a pydantic model (this module
    # imports pydantic and declares ``name``/``description`` pydantic-style).
    # Pydantic rejects assignment to undeclared attributes, so everything
    # that __init__ sets is declared here. Confirm against BaseTool.
    index_type: str = "L2"
    dimension: int = 384
    embedder_config: Optional[Dict[str, Any]] = None
    index: Any = None
    embedder: Any = None
    texts: List[str] = Field(default_factory=list)

    def __init__(
        self,
        index_type: str = "L2",
        dimension: int = 384,  # Default for BAAI/bge-small-en-v1.5
        embedder_config: Optional[Dict[str, Any]] = None,
    ):
        """Create an empty index and configure the embedder.

        Args:
            index_type: ``"L2"`` or ``"IP"``; selects the FAISS index class.
            dimension: Length of the vectors the index stores.
            embedder_config: Passed unchanged to
                ``EmbeddingConfigurator.configure_embedder``.

        Raises:
            ValueError: If ``index_type`` is not ``"L2"`` or ``"IP"``.
        """
        super().__init__()
        self.embedder_config = embedder_config
        self.dimension = dimension
        self.index = self._create_index(index_type)
        # Remembered so _run knows how to turn raw FAISS values into scores.
        self.index_type = index_type
        self.texts = []
        self._initialize_embedder()

    def _create_index(self, index_type: str) -> faiss.Index:
        """Return a new flat FAISS index of ``self.dimension`` for ``index_type``.

        Raises:
            ValueError: If ``index_type`` is not ``"L2"`` or ``"IP"``.
        """
        if index_type == "L2":
            return faiss.IndexFlatL2(self.dimension)
        if index_type == "IP":
            return faiss.IndexFlatIP(self.dimension)
        raise ValueError(f"Unsupported index type: {index_type}")

    def _initialize_embedder(self):
        """Build ``self.embedder`` from ``self.embedder_config``.

        NOTE(review): the rest of this class calls ``embed_text`` and
        ``embed_texts`` on the returned object; confirm that whatever
        ``EmbeddingConfigurator.configure_embedder`` returns provides them.
        """
        configurator = EmbeddingConfigurator()
        self.embedder = configurator.configure_embedder(self.embedder_config)

    def _score(self, raw: float) -> float:
        """Convert a raw FAISS result value into a higher-is-better score.

        For an ``"IP"`` index the raw value is already a similarity, so it is
        returned unchanged; feeding it through ``1 / (1 + x)`` would rank the
        best matches lowest and divide by zero at ``-1``. For ``"L2"`` the
        raw value is a non-negative distance, mapped into ``(0, 1]``.
        """
        if self.index_type == "IP":
            return raw
        return 1.0 / (1.0 + raw)

    def _run(
        self,
        query: str,
        k: int = 3,
        score_threshold: float = 0.6,
    ) -> List[Dict[str, Any]]:
        """Return up to ``k`` stored texts that match ``query``.

        Args:
            query: Text to embed and search for.
            k: Maximum number of neighbours to request from the index.
            score_threshold: Minimum score a hit needs to be returned.

        Returns:
            A list of ``{"text": str, "score": float}`` dicts in the order
            FAISS returned them. Empty when nothing has been added, when
            ``k`` is not positive, or when no hit reaches the threshold.
        """
        stored = self.index.ntotal
        if k <= 0 or stored == 0:
            return []

        query_embedding = self.embedder.embed_text(query)
        # Never ask for more neighbours than exist: FAISS pads the missing
        # slots with label -1, which would otherwise index texts from the end.
        distances, labels = self.index.search(
            np.array([query_embedding], dtype=np.float32),
            min(k, stored),
        )

        results = []
        for raw, label in zip(distances[0], labels[0]):
            position = int(label)
            if not 0 <= position < len(self.texts):
                continue
            # float() so callers get a plain Python float, not numpy.float32.
            score = self._score(float(raw))
            if score >= score_threshold:
                results.append({"text": self.texts[position], "score": score})
        return results

    def add_texts(self, texts: List[str]) -> None:
        """Embed ``texts`` and append them to the index.

        Args:
            texts: Texts to store; an empty list is a no-op.

        Raises:
            ValueError: If the embedder does not return exactly one vector of
                length ``self.dimension`` per text.
        """
        texts = list(texts)
        if not texts:
            # An empty batch would produce a 1-D array that FAISS cannot add.
            return

        embeddings = self.embedder.embed_texts(texts)
        vectors = np.array(embeddings, dtype=np.float32)
        if vectors.ndim != 2 or vectors.shape != (len(texts), self.dimension):
            raise ValueError(
                f"Expected embeddings of shape ({len(texts)}, {self.dimension}), "
                f"got {vectors.shape}"
            )

        # Add to the index first so self.texts never outgrows it on failure.
        self.index.add(vectors)
        self.texts.extend(texts)
|
||||
43
tests/tools/test_faiss_search_tool.py
Normal file
43
tests/tools/test_faiss_search_tool.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from crewai.tools import FAISSSearchTool
|
||||
|
||||
def test_faiss_search_tool_initialization():
    """A tool built with no arguments reports its default name and dimension."""
    expected_name = "FAISS Search Tool"
    expected_dimension = 384

    default_tool = FAISSSearchTool()

    assert default_tool.name == expected_name
    assert default_tool.dimension == expected_dimension
|
||||
|
||||
def test_faiss_search_with_texts():
    """Searching after adding texts yields at least one well-formed hit."""
    corpus = [
        "The quick brown fox",
        "jumps over the lazy dog",
        "A completely different text",
    ]
    searcher = FAISSSearchTool()
    searcher.add_texts(corpus)

    hits = searcher.run(query="quick fox", k=2, score_threshold=0.5)

    assert len(hits) > 0
    # Only the first hit is inspected: a string text and a float score.
    first_hit = hits[0]
    assert isinstance(first_hit["text"], str)
    assert isinstance(first_hit["score"], float)
|
||||
|
||||
def test_faiss_search_threshold_filtering():
    """A very high score threshold filters out every hit."""
    searcher = FAISSSearchTool()
    searcher.add_texts(["Text A", "Text B", "Text C"])

    high_threshold = 0.99
    hits = searcher.run(
        query="Something completely different",
        score_threshold=high_threshold,
    )

    # No results above threshold
    assert len(hits) == 0
|
||||
|
||||
def test_invalid_index_type():
    """An unrecognised index type is rejected with ValueError."""
    unsupported_type = "INVALID"
    with pytest.raises(ValueError):
        FAISSSearchTool(index_type=unsupported_type)
|
||||
Reference in New Issue
Block a user