Mirror of https://github.com/crewAIInc/crewAI.git
Synced 2026-05-01 07:13:00 +00:00

Compare commits: 1 commit (devin/1756...devin/1755)

Commit: 27f33b201d

docs/en/observability/confident-ai.mdx | 137 | Normal file
@@ -0,0 +1,137 @@
---
title: Confident AI Integration
description: Monitor and evaluate your CrewAI agents with Confident AI's comprehensive evaluation platform powered by DeepEval.
icon: shield-check
---

# Confident AI Overview

[Confident AI](https://confident-ai.com) is a comprehensive evaluation platform for LLM applications, powered by [DeepEval](https://github.com/confident-ai/deepeval). It provides advanced monitoring, evaluation, and optimization capabilities specifically designed for AI agent workflows.

Confident AI offers both tracing capabilities, to monitor your agents in real time, and evaluation tools, to assess the quality, safety, and performance of your CrewAI applications.

### Features

- **Real-time Monitoring**: Track agent interactions, task execution, and performance metrics
- **Comprehensive Evaluation**: Assess output quality, relevance, safety, and consistency
- **Cost Tracking**: Monitor LLM API usage and associated costs across your crews
- **Safety & Compliance**: Detect potential issues such as bias, toxicity, and PII leaks
- **Performance Analytics**: Analyze execution times, success rates, and bottlenecks
- **Custom Metrics**: Define and track domain-specific evaluation criteria
- **Team Collaboration**: Share insights and collaborate on agent optimization

## Setup Instructions

<Steps>
<Step title="Install Dependencies">
```shell
pip install deepeval crewai
```
</Step>
<Step title="Get API Key">
1. Sign up at [Confident AI](https://confident-ai.com)
2. Navigate to your project settings
3. Copy your API key
</Step>
<Step title="Configure CrewAI">
Instrument CrewAI with your Confident AI API key using `instrument_crewai`:

```python
from crewai import Agent, Crew, Task
from deepeval.integrations.crewai import instrument_crewai

instrument_crewai()

agent = Agent(
    role="Consultant",
    goal="Write clear, concise explanations.",
    backstory="An expert consultant with a keen eye for software trends.",
)

task = Task(
    description="Explain the importance of {topic}",
    expected_output="A clear and concise explanation of the topic.",
    agent=agent,
)

crew = Crew(agents=[agent], tasks=[task])

result = crew.kickoff(inputs={"topic": "AI"})
```
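If you prefer configuring the key outside of code: as an assumption (this detail is not in the diff above), DeepEval conventionally reads the Confident AI API key from the `CONFIDENT_API_KEY` environment variable, with `deepeval login` as an interactive alternative. A minimal sketch:

```shell
# Assumption: DeepEval picks up the key from CONFIDENT_API_KEY.
export CONFIDENT_API_KEY="your-confident-api-key"
# Then run your instrumented crew script, e.g.:
# python your_crew_script.py
```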
</Step>
<Step title="Add Evaluation (Optional)">
For comprehensive evaluation of your crew's outputs:

```python
from deepeval import evaluate
from deepeval.metrics import AnswerRelevancyMetric, FaithfulnessMetric
from deepeval.test_case import LLMTestCase

# Define evaluation metrics
relevancy_metric = AnswerRelevancyMetric(threshold=0.7)
faithfulness_metric = FaithfulnessMetric(threshold=0.8)

# Execute crew
result = crew.kickoff(inputs={"topic": "artificial intelligence"})

# Create test case for evaluation
# Note: FaithfulnessMetric typically also expects a retrieval_context
# on the test case.
test_case = LLMTestCase(
    input="Explain the importance of artificial intelligence",
    actual_output=str(result),
    expected_output="A comprehensive explanation of AI's significance"
)

# Evaluate the output
evaluate([test_case], [relevancy_metric, faithfulness_metric])
```
</Step>
<Step title="View Results">
After running your CrewAI application with the Confident AI integration:

1. Visit your [Confident AI dashboard](https://confident-ai.com/dashboard)
2. Navigate to your project to view traces and evaluations
3. Analyze agent performance, costs, and quality metrics
4. Set up alerts for performance thresholds or quality issues
</Step>
</Steps>

## Key Metrics Tracked

### Performance Metrics
- **Execution Time**: Duration of individual tasks and overall crew execution
- **Token Usage**: Input/output tokens consumed by each agent
- **API Latency**: Response times from LLM providers
- **Success Rate**: Percentage of successfully completed tasks

### Quality Metrics
- **Answer Relevancy**: How well outputs address the given tasks
- **Faithfulness**: Accuracy and consistency of agent responses
- **Coherence**: Logical flow and structure of outputs
- **Safety**: Detection of harmful or inappropriate content

### Cost Metrics
- **API Costs**: Real-time tracking of LLM usage costs
- **Cost per Task**: Economic efficiency analysis
- **Budget Monitoring**: Alerts for spending thresholds
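The "Custom Metrics" feature above can be made concrete. Below is a minimal sketch of a domain-specific criterion, a hypothetical keyword-coverage score, in plain Python; on Confident AI you would typically express such criteria as DeepEval metrics, but the scoring idea is the same:

```python
def keyword_coverage(output: str, required_keywords: list[str]) -> float:
    """Hypothetical domain-specific metric: fraction of required
    keywords that appear in the agent's output (case-insensitive)."""
    if not required_keywords:
        return 1.0
    text = output.lower()
    hits = sum(1 for kw in required_keywords if kw.lower() in text)
    return hits / len(required_keywords)

# Two of the three required keywords appear in this output.
score = keyword_coverage(
    "AI improves automation and decision-making.",
    ["automation", "decision-making", "ethics"],
)
```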

## Best Practices

### Development Phase
- Start with basic tracing to understand agent behavior
- Implement evaluation metrics early in development
- Use custom metrics for domain-specific requirements
- Monitor resource usage during testing

### Production Phase
- Set up comprehensive monitoring and alerting
- Track performance trends over time
- Implement automated quality checks
- Maintain cost visibility and control

### Continuous Improvement
- Review performance regularly using Confident AI analytics
- A/B test different agent configurations
- Build feedback loops for quality improvement
- Document optimization insights

For more detailed information and advanced configurations, visit the [Confident AI documentation](https://confident-ai.com/docs) and the [DeepEval documentation](https://docs.deepeval.com/).

@@ -57,6 +57,10 @@ Observability is crucial for understanding how your CrewAI agents perform, ident
  <Card title="Weave" icon="network-wired" href="/en/observability/weave">
    Weights & Biases platform for tracking and evaluating AI applications.
  </Card>

  <Card title="Confident AI" icon="shield-check" href="/en/observability/confident-ai">
    Comprehensive evaluation platform powered by DeepEval for monitoring and optimizing agent performance.
  </Card>
</CardGroup>

### Evaluation & Quality Assurance

@@ -23,7 +23,7 @@ dependencies = [
    # Data Handling
    "chromadb>=0.5.23",
    "tokenizers>=0.20.3",
    "onnxruntime>=1.22.1",
    "onnxruntime==1.22.0",
    "openpyxl>=3.1.5",
    "pyvis>=0.3.2",
    # Authentication and Security
@@ -68,9 +68,6 @@ docling = [
aisuite = [
    "aisuite>=0.1.10",
]
qdrant = [
    "qdrant-client[fastembed]>=1.14.3",
]

[tool.uv]
dev-dependencies = [

@@ -7,8 +7,7 @@ from rich.console import Console
from pydantic import BaseModel, Field

from .utils import validate_jwt_token
from crewai.cli.shared.token_manager import TokenManager
from .utils import TokenManager, validate_jwt_token
from urllib.parse import quote
from crewai.cli.plus_api import PlusAPI
from crewai.cli.config import Settings
@@ -22,19 +21,10 @@ console = Console()


class Oauth2Settings(BaseModel):
    provider: str = Field(
        description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
    )
    client_id: str = Field(
        description="OAuth2 client ID issued by the provider, used during authentication requests."
    )
    domain: str = Field(
        description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
    )
    audience: Optional[str] = Field(
        description="OAuth2 audience value, typically used to identify the target API or resource.",
        default=None,
    )
    provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).")
    client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.")
    domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.")
    audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None)

    @classmethod
    def from_settings(cls):
@@ -54,15 +44,11 @@ class ProviderFactory:
        settings = settings or Oauth2Settings.from_settings()

        import importlib

        module = importlib.import_module(
            f"crewai.cli.authentication.providers.{settings.provider.lower()}"
        )
        module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
        provider = getattr(module, f"{settings.provider.capitalize()}Provider")

        return provider(settings)


class AuthenticationCommand:
    def __init__(self):
        self.token_manager = TokenManager()
@@ -79,7 +65,7 @@ class AuthenticationCommand:
            provider="auth0",
            client_id=AUTH0_CLIENT_ID,
            domain=AUTH0_DOMAIN,
            audience=AUTH0_AUDIENCE,
            audience=AUTH0_AUDIENCE
        )
        self.oauth2_provider = ProviderFactory.from_settings(settings)
        # End of temporary code.
@@ -89,7 +75,9 @@ class AuthenticationCommand:

        return self._poll_for_token(device_code_data)

    def _get_device_code(self) -> Dict[str, Any]:
    def _get_device_code(
        self
    ) -> Dict[str, Any]:
        """Get the device code to authenticate the user."""

        device_code_payload = {
@@ -98,9 +86,7 @@ class AuthenticationCommand:
            "audience": self.oauth2_provider.get_audience(),
        }
        response = requests.post(
            url=self.oauth2_provider.get_authorize_url(),
            data=device_code_payload,
            timeout=20,
            url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20
        )
        response.raise_for_status()
        return response.json()
@@ -111,7 +97,9 @@ class AuthenticationCommand:
        console.print("2. Enter the following code: ", device_code_data["user_code"])
        webbrowser.open(device_code_data["verification_uri_complete"])

    def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
    def _poll_for_token(
        self, device_code_data: Dict[str, Any]
    ) -> None:
        """Polls the server for the token until it is received, or max attempts are reached."""

        token_payload = {
@@ -124,9 +112,7 @@ class AuthenticationCommand:

        attempts = 0
        while True and attempts < 10:
            response = requests.post(
                self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
            )
            response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30)
            token_data = response.json()

            if response.status_code == 200:
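The polling logic in `_poll_for_token` above (capped at 10 attempts) can be sketched in isolation with a stubbed transport. The `post` callable and the interval value here are illustrative, not part of the CrewAI CLI:

```python
import time


def poll_for_token(post, interval: float = 0.0, max_attempts: int = 10):
    """Poll a token endpoint until it succeeds or attempts run out.
    `post` is any callable returning (status_code, payload)."""
    for _ in range(max_attempts):
        status, payload = post()
        if status == 200:
            return payload["access_token"]
        time.sleep(interval)  # wait before the next poll
    return None


# Stubbed transport: fails twice (authorization pending), then succeeds.
responses = iter([(403, {}), (403, {}), (200, {"access_token": "tok"})])
token = poll_for_token(lambda: next(responses))
```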

@@ -1,4 +1,4 @@
from crewai.cli.shared.token_manager import TokenManager
from .utils import TokenManager


class AuthError(Exception):

@@ -1,5 +1,12 @@
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
import jwt
from jwt import PyJWKClient
from cryptography.fernet import Fernet


def validate_jwt_token(
@@ -60,3 +67,118 @@ def validate_jwt_token(
        raise Exception(f"JWKS or key processing error: {str(e)}")
    except jwt.InvalidTokenError as e:
        raise Exception(f"Invalid token: {str(e)}")


class TokenManager:
    def __init__(self, file_path: str = "tokens.enc") -> None:
        """
        Initialize the TokenManager class.

        :param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
        """
        self.file_path = file_path
        self.key = self._get_or_create_key()
        self.fernet = Fernet(self.key)

    def _get_or_create_key(self) -> bytes:
        """
        Get or create the encryption key.

        :return: The encryption key.
        """
        key_filename = "secret.key"
        key = self.read_secure_file(key_filename)

        if key is not None:
            return key

        new_key = Fernet.generate_key()
        self.save_secure_file(key_filename, new_key)
        return new_key

    def save_tokens(self, access_token: str, expires_at: int) -> None:
        """
        Save the access token and its expiration time.

        :param access_token: The access token to save.
        :param expires_at: The UNIX timestamp of the expiration time.
        """
        expiration_time = datetime.fromtimestamp(expires_at)
        data = {
            "access_token": access_token,
            "expiration": expiration_time.isoformat(),
        }
        encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
        self.save_secure_file(self.file_path, encrypted_data)

    def get_token(self) -> Optional[str]:
        """
        Get the access token if it is valid and not expired.

        :return: The access token if valid and not expired, otherwise None.
        """
        encrypted_data = self.read_secure_file(self.file_path)

        decrypted_data = self.fernet.decrypt(encrypted_data)  # type: ignore
        data = json.loads(decrypted_data)

        expiration = datetime.fromisoformat(data["expiration"])
        if expiration <= datetime.now():
            return None

        return data["access_token"]

    def get_secure_storage_path(self) -> Path:
        """
        Get the secure storage path based on the operating system.

        :return: The secure storage path.
        """
        if sys.platform == "win32":
            # Windows: Use %LOCALAPPDATA%
            base_path = os.environ.get("LOCALAPPDATA")
        elif sys.platform == "darwin":
            # macOS: Use ~/Library/Application Support
            base_path = os.path.expanduser("~/Library/Application Support")
        else:
            # Linux and other Unix-like: Use ~/.local/share
            base_path = os.path.expanduser("~/.local/share")

        app_name = "crewai/credentials"
        storage_path = Path(base_path) / app_name

        storage_path.mkdir(parents=True, exist_ok=True)

        return storage_path

    def save_secure_file(self, filename: str, content: bytes) -> None:
        """
        Save the content to a secure file.

        :param filename: The name of the file.
        :param content: The content to save.
        """
        storage_path = self.get_secure_storage_path()
        file_path = storage_path / filename

        with open(file_path, "wb") as f:
            f.write(content)

        # Set appropriate permissions (read/write for owner only)
        os.chmod(file_path, 0o600)

    def read_secure_file(self, filename: str) -> Optional[bytes]:
        """
        Read the content of a secure file.

        :param filename: The name of the file.
        :return: The content of the file if it exists, otherwise None.
        """
        storage_path = self.get_secure_storage_path()
        file_path = storage_path / filename

        if not file_path.exists():
            return None

        with open(file_path, "rb") as f:
            return f.read()
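The expiry check in `get_token` above compares a stored ISO-8601 timestamp against the current time. A standalone sketch of that rule (the helper name and dict shape are illustrative, mirroring the comparison used in `TokenManager.get_token`):

```python
from datetime import datetime, timedelta


def token_if_valid(data: dict) -> "str | None":
    """Return the stored token only if its expiration lies in the future,
    mirroring the comparison in TokenManager.get_token."""
    expiration = datetime.fromisoformat(data["expiration"])
    if expiration <= datetime.now():
        return None
    return data["access_token"]


fresh = {"access_token": "abc",
         "expiration": (datetime.now() + timedelta(hours=1)).isoformat()}
stale = {"access_token": "abc",
         "expiration": (datetime.now() - timedelta(hours=1)).isoformat()}
```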

@@ -11,7 +11,6 @@ from crewai.cli.constants import (
    CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
    CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
)
from crewai.cli.shared.token_manager import TokenManager

DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"

@@ -54,7 +53,6 @@ HIDDEN_SETTINGS_KEYS = [
    "tool_repository_password",
]


class Settings(BaseModel):
    enterprise_base_url: Optional[str] = Field(
        default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
@@ -76,12 +74,12 @@ class Settings(BaseModel):

    oauth2_provider: str = Field(
        description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
        default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
        default=DEFAULT_CLI_SETTINGS["oauth2_provider"]
    )

    oauth2_audience: Optional[str] = Field(
        description="OAuth2 audience value, typically used to identify the target API or resource.",
        default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
        default=DEFAULT_CLI_SETTINGS["oauth2_audience"]
    )

    oauth2_client_id: str = Field(
@@ -91,7 +89,7 @@ class Settings(BaseModel):

    oauth2_domain: str = Field(
        description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
        default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
        default=DEFAULT_CLI_SETTINGS["oauth2_domain"]
    )

    def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
@@ -118,7 +116,6 @@ class Settings(BaseModel):
        """Reset all settings to default values"""
        self._reset_user_settings()
        self._reset_cli_settings()
        self._clear_auth_tokens()
        self.dump()

    def dump(self) -> None:
@@ -142,7 +139,3 @@ class Settings(BaseModel):
        """Reset all CLI settings to default values"""
        for key in CLI_SETTINGS_KEYS:
            setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))

    def _clear_auth_tokens(self) -> None:
        """Clear all authentication tokens"""
        TokenManager().clear_tokens()

@@ -1,139 +0,0 @@
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
from cryptography.fernet import Fernet


class TokenManager:
    def __init__(self, file_path: str = "tokens.enc") -> None:
        """
        Initialize the TokenManager class.

        :param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
        """
        self.file_path = file_path
        self.key = self._get_or_create_key()
        self.fernet = Fernet(self.key)

    def _get_or_create_key(self) -> bytes:
        """
        Get or create the encryption key.

        :return: The encryption key.
        """
        key_filename = "secret.key"
        key = self.read_secure_file(key_filename)

        if key is not None:
            return key

        new_key = Fernet.generate_key()
        self.save_secure_file(key_filename, new_key)
        return new_key

    def save_tokens(self, access_token: str, expires_at: int) -> None:
        """
        Save the access token and its expiration time.

        :param access_token: The access token to save.
        :param expires_at: The UNIX timestamp of the expiration time.
        """
        expiration_time = datetime.fromtimestamp(expires_at)
        data = {
            "access_token": access_token,
            "expiration": expiration_time.isoformat(),
        }
        encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
        self.save_secure_file(self.file_path, encrypted_data)

    def get_token(self) -> Optional[str]:
        """
        Get the access token if it is valid and not expired.

        :return: The access token if valid and not expired, otherwise None.
        """
        encrypted_data = self.read_secure_file(self.file_path)

        decrypted_data = self.fernet.decrypt(encrypted_data)  # type: ignore
        data = json.loads(decrypted_data)

        expiration = datetime.fromisoformat(data["expiration"])
        if expiration <= datetime.now():
            return None

        return data["access_token"]

    def clear_tokens(self) -> None:
        """
        Clear the tokens.
        """
        self.delete_secure_file(self.file_path)

    def get_secure_storage_path(self) -> Path:
        """
        Get the secure storage path based on the operating system.

        :return: The secure storage path.
        """
        if sys.platform == "win32":
            # Windows: Use %LOCALAPPDATA%
            base_path = os.environ.get("LOCALAPPDATA")
        elif sys.platform == "darwin":
            # macOS: Use ~/Library/Application Support
            base_path = os.path.expanduser("~/Library/Application Support")
        else:
            # Linux and other Unix-like: Use ~/.local/share
            base_path = os.path.expanduser("~/.local/share")

        app_name = "crewai/credentials"
        storage_path = Path(base_path) / app_name

        storage_path.mkdir(parents=True, exist_ok=True)

        return storage_path

    def save_secure_file(self, filename: str, content: bytes) -> None:
        """
        Save the content to a secure file.

        :param filename: The name of the file.
        :param content: The content to save.
        """
        storage_path = self.get_secure_storage_path()
        file_path = storage_path / filename

        with open(file_path, "wb") as f:
            f.write(content)

        # Set appropriate permissions (read/write for owner only)
        os.chmod(file_path, 0o600)

    def read_secure_file(self, filename: str) -> Optional[bytes]:
        """
        Read the content of a secure file.

        :param filename: The name of the file.
        :return: The content of the file if it exists, otherwise None.
        """
        storage_path = self.get_secure_storage_path()
        file_path = storage_path / filename

        if not file_path.exists():
            return None

        with open(file_path, "rb") as f:
            return f.read()

    def delete_secure_file(self, filename: str) -> None:
        """
        Delete the secure file.

        :param filename: The name of the file.
        """
        storage_path = self.get_secure_storage_path()
        file_path = storage_path / filename
        if file_path.exists():
            file_path.unlink(missing_ok=True)

@@ -1,26 +0,0 @@
"""Core exceptions for RAG module."""


class ClientMethodMismatchError(TypeError):
    """Raised when a method is called with the wrong client type.

    Typically used when a sync method is called with an async client,
    or vice versa.
    """

    def __init__(
        self, method_name: str, expected_client: str, alt_method: str, alt_client: str
    ) -> None:
        """Create a ClientMethodMismatchError.

        Args:
            method_name: Method that was called incorrectly.
            expected_client: Required client type.
            alt_method: Suggested alternative method.
            alt_client: Client type for the alternative method.
        """
        message = (
            f"Method {method_name}() requires a {expected_client}. "
            f"Use {alt_method}() for {alt_client}."
        )
        super().__init__(message)
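The deleted exception above composes its message from four strings. A quick check of the resulting text, with the class reproduced verbatim from the diff so the snippet is self-contained:

```python
class ClientMethodMismatchError(TypeError):
    """Raised when a method is called with the wrong client type."""

    def __init__(self, method_name: str, expected_client: str,
                 alt_method: str, alt_client: str) -> None:
        # Same message format as in the deleted crewai.rag.core.exceptions.
        message = (
            f"Method {method_name}() requires a {expected_client}. "
            f"Use {alt_method}() for {alt_client}."
        )
        super().__init__(message)


err = ClientMethodMismatchError(
    "create_collection", "QdrantClient",
    "acreate_collection", "AsyncQdrantClient",
)
# str(err) → "Method create_collection() requires a QdrantClient. Use acreate_collection() for AsyncQdrantClient."
```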

@@ -1 +0,0 @@
"""Qdrant vector database client implementation."""

@@ -1,527 +0,0 @@
"""Qdrant client implementation."""

from typing import Any, cast

from fastembed import TextEmbedding
from qdrant_client import QdrantClient as SyncQdrantClientBase
from typing_extensions import Unpack

from crewai.rag.core.base_client import (
    BaseClient,
    BaseCollectionParams,
    BaseCollectionAddParams,
    BaseCollectionSearchParams,
)
from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.qdrant.types import (
    AsyncEmbeddingFunction,
    EmbeddingFunction,
    QdrantClientParams,
    QdrantClientType,
    QdrantCollectionCreateParams,
)
from crewai.rag.qdrant.utils import (
    _is_async_client,
    _is_async_embedding_function,
    _is_sync_client,
    _create_point_from_document,
    _get_collection_params,
    _prepare_search_params,
    _process_search_results,
)
from crewai.rag.types import SearchResult


class QdrantClient(BaseClient):
    """Qdrant implementation of the BaseClient protocol.

    Provides vector database operations for Qdrant, supporting both
    synchronous and asynchronous clients.

    Attributes:
        client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
        embedding_function: Function to generate embeddings for documents.
    """

    client: QdrantClientType
    embedding_function: EmbeddingFunction | AsyncEmbeddingFunction

    def __init__(
        self,
        client: QdrantClientType | None = None,
        embedding_function: EmbeddingFunction | AsyncEmbeddingFunction | None = None,
        **kwargs: Unpack[QdrantClientParams],
    ) -> None:
        """Initialize QdrantClient with optional client and embedding function.

        Args:
            client: Optional pre-configured Qdrant client instance.
            embedding_function: Optional embedding function. If not provided,
                uses FastEmbed's BAAI/bge-small-en-v1.5 model.
            **kwargs: Additional arguments for QdrantClient creation.
        """
        if client is not None:
            self.client = client
        else:
            location = kwargs.get("location", ":memory:")
            client_kwargs = {k: v for k, v in kwargs.items() if k != "location"}
            self.client = SyncQdrantClientBase(location, **cast(Any, client_kwargs))

        if embedding_function is not None:
            self.embedding_function = embedding_function
        else:
            _embedder = TextEmbedding("BAAI/bge-small-en-v1.5")

            def _embed_fn(text: str) -> list[float]:
                embeddings = list(_embedder.embed([text]))
                return [float(x) for x in embeddings[0]] if embeddings else []

            self.embedding_function = _embed_fn

    def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
        """Create a new collection in Qdrant.

        Keyword Args:
            collection_name: Name of the collection to create. Must be unique.
            vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
            sparse_vectors_config: Optional sparse vector configuration.
            shard_number: Optional number of shards.
            replication_factor: Optional replication factor.
            write_consistency_factor: Optional write consistency factor.
            on_disk_payload: Optional flag to store payload on disk.
            hnsw_config: Optional HNSW index configuration.
            optimizers_config: Optional optimizer configuration.
            wal_config: Optional write-ahead log configuration.
            quantization_config: Optional quantization configuration.
            init_from: Optional collection to initialize from.
            timeout: Optional timeout for the operation.

        Raises:
            ValueError: If collection with the same name already exists.
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="create_collection",
                expected_client="QdrantClient",
                alt_method="acreate_collection",
                alt_client="AsyncQdrantClient",
            )

        collection_name = kwargs["collection_name"]

        if self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' already exists")

        params = _get_collection_params(kwargs)
        self.client.create_collection(**params)

    async def acreate_collection(
        self, **kwargs: Unpack[QdrantCollectionCreateParams]
    ) -> None:
        """Create a new collection in Qdrant asynchronously.

        Keyword Args:
            collection_name: Name of the collection to create. Must be unique.
            vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
            sparse_vectors_config: Optional sparse vector configuration.
            shard_number: Optional number of shards.
            replication_factor: Optional replication factor.
            write_consistency_factor: Optional write consistency factor.
            on_disk_payload: Optional flag to store payload on disk.
            hnsw_config: Optional HNSW index configuration.
            optimizers_config: Optional optimizer configuration.
            wal_config: Optional write-ahead log configuration.
            quantization_config: Optional quantization configuration.
            init_from: Optional collection to initialize from.
            timeout: Optional timeout for the operation.

        Raises:
            ValueError: If collection with the same name already exists.
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="acreate_collection",
                expected_client="AsyncQdrantClient",
                alt_method="create_collection",
                alt_client="QdrantClient",
            )

        collection_name = kwargs["collection_name"]

        if await self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' already exists")

        params = _get_collection_params(kwargs)
        await self.client.create_collection(**params)

    def get_or_create_collection(
        self, **kwargs: Unpack[QdrantCollectionCreateParams]
    ) -> Any:
        """Get an existing collection or create it if it doesn't exist.

        Keyword Args:
            collection_name: Name of the collection to get or create.
            vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
            sparse_vectors_config: Optional sparse vector configuration.
            shard_number: Optional number of shards.
            replication_factor: Optional replication factor.
            write_consistency_factor: Optional write consistency factor.
            on_disk_payload: Optional flag to store payload on disk.
            hnsw_config: Optional HNSW index configuration.
            optimizers_config: Optional optimizer configuration.
            wal_config: Optional write-ahead log configuration.
            quantization_config: Optional quantization configuration.
            init_from: Optional collection to initialize from.
            timeout: Optional timeout for the operation.

        Returns:
            Collection info dict with name and other metadata.

        Raises:
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="get_or_create_collection",
                expected_client="QdrantClient",
                alt_method="aget_or_create_collection",
                alt_client="AsyncQdrantClient",
            )

        collection_name = kwargs["collection_name"]

        if self.client.collection_exists(collection_name):
            return self.client.get_collection(collection_name)

        params = _get_collection_params(kwargs)
        self.client.create_collection(**params)

        return self.client.get_collection(collection_name)

    async def aget_or_create_collection(
        self, **kwargs: Unpack[QdrantCollectionCreateParams]
    ) -> Any:
        """Get an existing collection or create it if it doesn't exist asynchronously.

        Keyword Args:
            collection_name: Name of the collection to get or create.
            vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
            sparse_vectors_config: Optional sparse vector configuration.
            shard_number: Optional number of shards.
            replication_factor: Optional replication factor.
            write_consistency_factor: Optional write consistency factor.
            on_disk_payload: Optional flag to store payload on disk.
            hnsw_config: Optional HNSW index configuration.
            optimizers_config: Optional optimizer configuration.
            wal_config: Optional write-ahead log configuration.
            quantization_config: Optional quantization configuration.
            init_from: Optional collection to initialize from.
            timeout: Optional timeout for the operation.

        Returns:
            Collection info dict with name and other metadata.

        Raises:
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="aget_or_create_collection",
                expected_client="AsyncQdrantClient",
                alt_method="get_or_create_collection",
                alt_client="QdrantClient",
            )

        collection_name = kwargs["collection_name"]

        if await self.client.collection_exists(collection_name):
            return await self.client.get_collection(collection_name)

        params = _get_collection_params(kwargs)
        await self.client.create_collection(**params)

        return await self.client.get_collection(collection_name)

    def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents with their embeddings to a collection.

        Keyword Args:
            collection_name: The name of the collection to add documents to.
            documents: List of BaseRecord dicts containing document data.
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="add_documents",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="aadd_documents",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
points = []
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync add_documents. "
|
||||
"Use aadd_documents instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
|
||||
self.client.upsert(collection_name=collection_name, points=points, wait=True)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aadd_documents",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="add_documents",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
points = []
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
embedding = await async_fn(doc["content"])
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
|
||||
await self.client.upsert(
|
||||
collection_name=collection_name, points=points, wait=True
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="search",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="asearch",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync search. "
|
||||
"Use asearch instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_kwargs = _prepare_search_params(
|
||||
collection_name=collection_name,
|
||||
query_embedding=query_embedding,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
response = self.client.query_points(**search_kwargs)
|
||||
return _process_search_results(response)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="asearch",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="search",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
query_embedding = await async_fn(query)
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_kwargs = _prepare_search_params(
|
||||
collection_name=collection_name,
|
||||
query_embedding=query_embedding,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
response = await self.client.query_points(**search_kwargs)
|
||||
return _process_search_results(response)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="delete_collection",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="adelete_collection",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
self.client.delete_collection(collection_name=collection_name)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="adelete_collection",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="delete_collection",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
await self.client.delete_collection(collection_name=collection_name)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="reset",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="areset",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collections_response = self.client.get_collections()
|
||||
|
||||
for collection in collections_response.collections:
|
||||
self.client.delete_collection(collection_name=collection.name)
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="areset",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="reset",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collections_response = await self.client.get_collections()
|
||||
|
||||
for collection in collections_response.collections:
|
||||
await self.client.delete_collection(collection_name=collection.name)
|
||||
@@ -1,7 +0,0 @@
-"""Constants for Qdrant implementation."""
-
-from typing import Final
-
-from qdrant_client.models import Distance, VectorParams
-
-DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE)
@@ -1,134 +0,0 @@
-"""Type definitions specific to Qdrant implementation."""
-
-from collections.abc import Awaitable, Callable
-from typing import Annotated, Any, Protocol, TypeAlias, TypedDict
-from typing_extensions import NotRequired
-
-import numpy as np
-from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
-from qdrant_client.models import (
-    FieldCondition,
-    Filter,
-    HasIdCondition,
-    HasVectorCondition,
-    HnswConfigDiff,
-    InitFrom,
-    IsEmptyCondition,
-    IsNullCondition,
-    NestedCondition,
-    OptimizersConfigDiff,
-    QuantizationConfig,
-    ShardingMethod,
-    SparseVectorsConfig,
-    VectorsConfig,
-    WalConfigDiff,
-)
-
-from crewai.rag.core.base_client import BaseCollectionParams
-
-QdrantClientType = SyncQdrantClient | AsyncQdrantClient
-
-QueryEmbedding: TypeAlias = list[float] | np.ndarray[Any, np.dtype[np.floating[Any]]]
-
-BasicConditions = FieldCondition | IsEmptyCondition | IsNullCondition
-StructuralConditions = HasIdCondition | HasVectorCondition | NestedCondition
-FilterCondition = BasicConditions | StructuralConditions | Filter
-
-MetadataFilterValue = bool | int | str
-MetadataFilter = dict[str, MetadataFilterValue]
-
-
-class EmbeddingFunction(Protocol):
-    """Protocol for embedding functions that convert text to vectors."""
-
-    def __call__(self, text: str) -> QueryEmbedding:
-        """Convert text to embedding vector.
-
-        Args:
-            text: Input text to embed.
-
-        Returns:
-            Embedding vector as list of floats or numpy array.
-        """
-        ...
-
-
-class AsyncEmbeddingFunction(Protocol):
-    """Protocol for async embedding functions that convert text to vectors."""
-
-    async def __call__(self, text: str) -> QueryEmbedding:
-        """Convert text to embedding vector asynchronously.
-
-        Args:
-            text: Input text to embed.
-
-        Returns:
-            Embedding vector as list of floats or numpy array.
-        """
-        ...
-
-
-class QdrantClientParams(TypedDict, total=False):
-    """Parameters for QdrantClient initialization."""
-
-    location: str | None
-    url: str | None
-    port: int
-    grpc_port: int
-    prefer_grpc: bool
-    https: bool | None
-    api_key: str | None
-    prefix: str | None
-    timeout: int | None
-    host: str | None
-    path: str | None
-    force_disable_check_same_thread: bool
-    grpc_options: dict[str, Any] | None
-    auth_token_provider: Callable[[], str] | Callable[[], Awaitable[str]] | None
-    cloud_inference: bool
-    local_inference_batch_size: int | None
-    check_compatibility: bool
-
-
-class CommonCreateFields(TypedDict, total=False):
-    """Fields shared between high-level and direct create_collection params."""
-
-    vectors_config: VectorsConfig
-    sparse_vectors_config: SparseVectorsConfig
-    shard_number: Annotated[int, "Number of shards (default: 1)"]
-    sharding_method: ShardingMethod
-    replication_factor: Annotated[int, "Number of replicas per shard (default: 1)"]
-    write_consistency_factor: Annotated[int, "Await N replicas on write (default: 1)"]
-    on_disk_payload: Annotated[bool, "Store payload on disk instead of RAM"]
-    hnsw_config: HnswConfigDiff
-    optimizers_config: OptimizersConfigDiff
-    wal_config: WalConfigDiff
-    quantization_config: QuantizationConfig
-    init_from: InitFrom | str
-    timeout: Annotated[int, "Operation timeout in seconds"]
-
-
-class QdrantCollectionCreateParams(
-    BaseCollectionParams, CommonCreateFields, total=False
-):
-    """High-level parameters for creating a Qdrant collection."""
-
-    pass
-
-
-class CreateCollectionParams(CommonCreateFields, total=False):
-    """Parameters for qdrant_client.create_collection."""
-
-    collection_name: str
-
-
-class PreparedSearchParams(TypedDict):
-    """Type definition for prepared Qdrant search parameters."""
-
-    collection_name: str
-    query: list[float]
-    limit: Annotated[int, "Max results to return"]
-    with_payload: Annotated[bool, "Include payload in results"]
-    with_vectors: Annotated[bool, "Include vectors in results"]
-    score_threshold: NotRequired[Annotated[float, "Min similarity score (0-1)"]]
-    query_filter: NotRequired[Filter]
@@ -1,228 +0,0 @@
-"""Utility functions for Qdrant operations."""
-
-import asyncio
-from typing import TypeGuard
-from uuid import uuid4
-
-from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
-from qdrant_client.models import (
-    FieldCondition,
-    Filter,
-    MatchValue,
-    PointStruct,
-    QueryResponse,
-)
-
-from crewai.rag.qdrant.constants import DEFAULT_VECTOR_PARAMS
-from crewai.rag.qdrant.types import (
-    AsyncEmbeddingFunction,
-    CreateCollectionParams,
-    EmbeddingFunction,
-    FilterCondition,
-    MetadataFilter,
-    PreparedSearchParams,
-    QdrantClientType,
-    QdrantCollectionCreateParams,
-    QueryEmbedding,
-)
-from crewai.rag.types import SearchResult, BaseRecord
-
-
-def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
-    """Convert embedding to list[float] format if needed.
-
-    Args:
-        embedding: Embedding vector as list or numpy array.
-
-    Returns:
-        Embedding as list[float].
-    """
-    if not isinstance(embedding, list):
-        return embedding.tolist()
-    return embedding
-
-
-def _is_sync_client(client: QdrantClientType) -> TypeGuard[SyncQdrantClient]:
-    """Type guard to check if the client is a synchronous QdrantClient.
-
-    Args:
-        client: The client to check.
-
-    Returns:
-        True if the client is a QdrantClient, False otherwise.
-    """
-    return isinstance(client, SyncQdrantClient)
-
-
-def _is_async_client(client: QdrantClientType) -> TypeGuard[AsyncQdrantClient]:
-    """Type guard to check if the client is an asynchronous AsyncQdrantClient.
-
-    Args:
-        client: The client to check.
-
-    Returns:
-        True if the client is an AsyncQdrantClient, False otherwise.
-    """
-    return isinstance(client, AsyncQdrantClient)
-
-
-def _is_async_embedding_function(
-    func: EmbeddingFunction | AsyncEmbeddingFunction,
-) -> TypeGuard[AsyncEmbeddingFunction]:
-    """Type guard to check if the embedding function is async.
-
-    Args:
-        func: The embedding function to check.
-
-    Returns:
-        True if the function is async, False otherwise.
-    """
-    return asyncio.iscoroutinefunction(func)
-
-
-def _get_collection_params(
-    kwargs: QdrantCollectionCreateParams,
-) -> CreateCollectionParams:
-    """Extract collection creation parameters from kwargs."""
-    params: CreateCollectionParams = {
-        "collection_name": kwargs["collection_name"],
-        "vectors_config": kwargs.get("vectors_config", DEFAULT_VECTOR_PARAMS),
-    }
-
-    if "sparse_vectors_config" in kwargs:
-        params["sparse_vectors_config"] = kwargs["sparse_vectors_config"]
-    if "shard_number" in kwargs:
-        params["shard_number"] = kwargs["shard_number"]
-    if "sharding_method" in kwargs:
-        params["sharding_method"] = kwargs["sharding_method"]
-    if "replication_factor" in kwargs:
-        params["replication_factor"] = kwargs["replication_factor"]
-    if "write_consistency_factor" in kwargs:
-        params["write_consistency_factor"] = kwargs["write_consistency_factor"]
-    if "on_disk_payload" in kwargs:
-        params["on_disk_payload"] = kwargs["on_disk_payload"]
-    if "hnsw_config" in kwargs:
-        params["hnsw_config"] = kwargs["hnsw_config"]
-    if "optimizers_config" in kwargs:
-        params["optimizers_config"] = kwargs["optimizers_config"]
-    if "wal_config" in kwargs:
-        params["wal_config"] = kwargs["wal_config"]
-    if "quantization_config" in kwargs:
-        params["quantization_config"] = kwargs["quantization_config"]
-    if "init_from" in kwargs:
-        params["init_from"] = kwargs["init_from"]
-    if "timeout" in kwargs:
-        params["timeout"] = kwargs["timeout"]
-
-    return params
-
-
-def _prepare_search_params(
-    collection_name: str,
-    query_embedding: QueryEmbedding,
-    limit: int,
-    score_threshold: float | None,
-    metadata_filter: MetadataFilter | None,
-) -> PreparedSearchParams:
-    """Prepare search parameters for Qdrant query_points.
-
-    Args:
-        collection_name: Name of the collection to search.
-        query_embedding: Embedding vector for the query.
-        limit: Maximum number of results.
-        score_threshold: Optional minimum similarity score.
-        metadata_filter: Optional metadata filters.
-
-    Returns:
-        Dictionary of parameters for query_points method.
-    """
-    query_vector = _ensure_list_embedding(query_embedding)
-
-    search_kwargs: PreparedSearchParams = {
-        "collection_name": collection_name,
-        "query": query_vector,
-        "limit": limit,
-        "with_payload": True,
-        "with_vectors": False,
-    }
-
-    if score_threshold is not None:
-        search_kwargs["score_threshold"] = score_threshold
-
-    if metadata_filter:
-        filter_conditions: list[FilterCondition] = []
-        for key, value in metadata_filter.items():
-            filter_conditions.append(
-                FieldCondition(key=key, match=MatchValue(value=value))
-            )
-
-        search_kwargs["query_filter"] = Filter(must=filter_conditions)
-
-    return search_kwargs
-
-
-def _normalize_qdrant_score(score: float) -> float:
-    """Normalize Qdrant cosine similarity score to [0, 1] range.
-
-    Converts from Qdrant's [-1, 1] cosine similarity range to [0, 1] range for standardization across clients.
-
-    Args:
-        score: Raw cosine similarity score from Qdrant [-1, 1].
-
-    Returns:
-        Normalized score in [0, 1] range where 1 is most similar.
-    """
-    normalized = (score + 1.0) / 2.0
-    return max(0.0, min(1.0, normalized))
-
-
-def _process_search_results(response: QueryResponse) -> list[SearchResult]:
-    """Process Qdrant search response into SearchResult format.
-
-    Args:
-        response: Response from Qdrant query_points method.
-
-    Returns:
-        List of SearchResult dictionaries.
-    """
-    results: list[SearchResult] = []
-    for point in response.points:
-        payload = point.payload or {}
-        score = _normalize_qdrant_score(score=point.score)
-        result: SearchResult = {
-            "id": str(point.id),
-            "content": payload.get("content", ""),
-            "metadata": {k: v for k, v in payload.items() if k != "content"},
-            "score": score,
-        }
-        results.append(result)
-
-    return results
-
-
-def _create_point_from_document(
-    doc: BaseRecord, embedding: QueryEmbedding
-) -> PointStruct:
-    """Create a PointStruct from a document and its embedding.
-
-    Args:
-        doc: Document dictionary containing content, metadata, and optional doc_id.
-        embedding: The embedding vector for the document content.
-
-    Returns:
-        PointStruct ready to be upserted to Qdrant.
-    """
-    doc_id = doc.get("doc_id", str(uuid4()))
-    vector = _ensure_list_embedding(embedding)
-
-    metadata = doc.get("metadata", {})
-    if isinstance(metadata, list):
-        metadata = metadata[0] if metadata else {}
-    elif not isinstance(metadata, dict):
-        metadata = dict(metadata) if metadata else {}
-
-    return PointStruct(
-        id=doc_id,
-        vector=vector,
-        payload={"content": doc["content"], **metadata},
-    )
@@ -1,32 +0,0 @@
-"""Import utilities for optional dependencies."""
-
-import importlib
-from types import ModuleType
-
-
-class OptionalDependencyError(ImportError):
-    """Exception raised when an optional dependency is not installed."""
-
-    pass
-
-
-def require(name: str, *, purpose: str) -> ModuleType:
-    """Import a module, raising a helpful error if it's not installed.
-
-    Args:
-        name: The module name to import.
-        purpose: Description of what requires this dependency.
-
-    Returns:
-        The imported module.
-
-    Raises:
-        OptionalDependencyError: If the module is not installed.
-    """
-    try:
-        return importlib.import_module(name)
-    except ImportError as exc:
-        raise OptionalDependencyError(
-            f"{purpose} requires the optional dependency '{name}'.\n"
-            f"Install it with: uv add {name}"
-        ) from exc
@@ -1,14 +1,17 @@
+import json
 import jwt
 import unittest
+from datetime import datetime, timedelta
 from unittest.mock import MagicMock, patch

+from cryptography.fernet import Fernet
+
-from crewai.cli.authentication.utils import validate_jwt_token
+from crewai.cli.authentication.utils import TokenManager, validate_jwt_token


 @patch("crewai.cli.authentication.utils.PyJWKClient", return_value=MagicMock())
 @patch("crewai.cli.authentication.utils.jwt")
-class TestUtils(unittest.TestCase):
+class TestValidateToken(unittest.TestCase):
     def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
         mock_jwt.decode.return_value = {"exp": 1719859200}

@@ -102,3 +105,121 @@ class TestUtils(unittest.TestCase):
             issuer="https://mock_issuer",
             audience="app_id_xxxx",
         )
+
+
+class TestTokenManager(unittest.TestCase):
+    @patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key")
+    def setUp(self, mock_get_key):
+        mock_get_key.return_value = Fernet.generate_key()
+        self.token_manager = TokenManager()
+
+    @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
+    @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
+    @patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key")
+    def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read):
+        mock_key = Fernet.generate_key()
+        mock_get_or_create.return_value = mock_key
+
+        token_manager = TokenManager()
+        result = token_manager.key
+
+        self.assertEqual(result, mock_key)
+
+    @patch("crewai.cli.authentication.utils.Fernet.generate_key")
+    @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
+    @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
+    def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate):
+        mock_key = b"new_key"
+        mock_read.return_value = None
+        mock_generate.return_value = mock_key
+
+        result = self.token_manager._get_or_create_key()
+
+        self.assertEqual(result, mock_key)
+        mock_read.assert_called_once_with("secret.key")
+        mock_generate.assert_called_once()
+        mock_save.assert_called_once_with("secret.key", mock_key)
+
+    @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
+    def test_save_tokens(self, mock_save):
+        access_token = "test_token"
+        expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
+
+        self.token_manager.save_tokens(access_token, expires_at)
+
+        mock_save.assert_called_once()
+        args = mock_save.call_args[0]
+        self.assertEqual(args[0], "tokens.enc")
+        decrypted_data = self.token_manager.fernet.decrypt(args[1])
+        data = json.loads(decrypted_data)
+        self.assertEqual(data["access_token"], access_token)
+        expiration = datetime.fromisoformat(data["expiration"])
+        self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
+
+    @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
+    def test_get_token_valid(self, mock_read):
+        access_token = "test_token"
+        expiration = (datetime.now() + timedelta(hours=1)).isoformat()
+        data = {"access_token": access_token, "expiration": expiration}
+        encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
+        mock_read.return_value = encrypted_data
+
+        result = self.token_manager.get_token()
+
+        self.assertEqual(result, access_token)
+
+    @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
+    def test_get_token_expired(self, mock_read):
+        access_token = "test_token"
+        expiration = (datetime.now() - timedelta(hours=1)).isoformat()
+        data = {"access_token": access_token, "expiration": expiration}
+        encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
+        mock_read.return_value = encrypted_data
+
+        result = self.token_manager.get_token()
+
+        self.assertIsNone(result)
+
+    @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
+    @patch("builtins.open", new_callable=unittest.mock.mock_open)
+    @patch("crewai.cli.authentication.utils.os.chmod")
+    def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path):
+        mock_path = MagicMock()
+        mock_get_path.return_value = mock_path
+        filename = "test_file.txt"
+        content = b"test_content"
+
+        self.token_manager.save_secure_file(filename, content)
+
+        mock_path.__truediv__.assert_called_once_with(filename)
+        mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb")
+        mock_open().write.assert_called_once_with(content)
+        mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600)
+
+    @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
+    @patch(
+        "builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content"
+    )
+    def test_read_secure_file_exists(self, mock_open, mock_get_path):
+        mock_path = MagicMock()
+        mock_get_path.return_value = mock_path
+        mock_path.__truediv__.return_value.exists.return_value = True
+        filename = "test_file.txt"
+
+        result = self.token_manager.read_secure_file(filename)
+
+        self.assertEqual(result, b"test_content")
+        mock_path.__truediv__.assert_called_once_with(filename)
+        mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb")
+
+    @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
+    def test_read_secure_file_not_exists(self, mock_get_path):
+        mock_path = MagicMock()
+        mock_get_path.return_value = mock_path
+        mock_path.__truediv__.return_value.exists.return_value = False
+        filename = "test_file.txt"
+
+        result = self.token_manager.read_secure_file(filename)
+
+        self.assertIsNone(result)
+        mock_path.__truediv__.assert_called_once_with(filename)
@@ -3,7 +3,6 @@ import shutil
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch, MagicMock

from crewai.cli.config import (
    Settings,
@@ -11,8 +10,6 @@ from crewai.cli.config import (
    CLI_SETTINGS_KEYS,
    DEFAULT_CLI_SETTINGS,
)
from crewai.cli.shared.token_manager import TokenManager
from datetime import datetime, timedelta


class TestSettings(unittest.TestCase):
@@ -69,8 +66,7 @@ class TestSettings(unittest.TestCase):
        for key in user_settings.keys():
            self.assertEqual(getattr(settings, key), None)

    @patch("crewai.cli.config.TokenManager")
    def test_reset_settings(self, mock_token_manager):
    def test_reset_settings(self):
        user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
        cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}

@@ -78,11 +74,6 @@
            config_path=self.config_path, **user_settings, **cli_settings
        )

        mock_token_manager.return_value = MagicMock()
        TokenManager().save_tokens(
            "aaa.bbb.ccc", (datetime.now() + timedelta(seconds=36000)).timestamp()
        )

        settings.reset()

        for key in user_settings.keys():
@@ -90,8 +81,6 @@
        for key in cli_settings.keys():
            self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))

        mock_token_manager.return_value.clear_tokens.assert_called_once()

    def test_dump_new_settings(self):
        settings = Settings(
            config_path=self.config_path, tool_repository_username="user1"

@@ -1,138 +0,0 @@
import json
import unittest
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

from cryptography.fernet import Fernet

from crewai.cli.shared.token_manager import TokenManager


class TestTokenManager(unittest.TestCase):
    @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
    def setUp(self, mock_get_key):
        mock_get_key.return_value = Fernet.generate_key()
        self.token_manager = TokenManager()

    @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
    @patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
    @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
    def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read):
        mock_key = Fernet.generate_key()
        mock_get_or_create.return_value = mock_key

        token_manager = TokenManager()
        result = token_manager.key

        self.assertEqual(result, mock_key)

    @patch("crewai.cli.shared.token_manager.Fernet.generate_key")
    @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
    @patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
    def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate):
        mock_key = b"new_key"
        mock_read.return_value = None
        mock_generate.return_value = mock_key

        result = self.token_manager._get_or_create_key()

        self.assertEqual(result, mock_key)
        mock_read.assert_called_once_with("secret.key")
        mock_generate.assert_called_once()
        mock_save.assert_called_once_with("secret.key", mock_key)

    @patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
    def test_save_tokens(self, mock_save):
        access_token = "test_token"
        expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())

        self.token_manager.save_tokens(access_token, expires_at)

        mock_save.assert_called_once()
        args = mock_save.call_args[0]
        self.assertEqual(args[0], "tokens.enc")
        decrypted_data = self.token_manager.fernet.decrypt(args[1])
        data = json.loads(decrypted_data)
        self.assertEqual(data["access_token"], access_token)
        expiration = datetime.fromisoformat(data["expiration"])
        self.assertEqual(expiration, datetime.fromtimestamp(expires_at))

    @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
    def test_get_token_valid(self, mock_read):
        access_token = "test_token"
        expiration = (datetime.now() + timedelta(hours=1)).isoformat()
        data = {"access_token": access_token, "expiration": expiration}
        encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
        mock_read.return_value = encrypted_data

        result = self.token_manager.get_token()

        self.assertEqual(result, access_token)

    @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
    def test_get_token_expired(self, mock_read):
        access_token = "test_token"
        expiration = (datetime.now() - timedelta(hours=1)).isoformat()
        data = {"access_token": access_token, "expiration": expiration}
        encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
        mock_read.return_value = encrypted_data

        result = self.token_manager.get_token()

        self.assertIsNone(result)

    @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
    @patch("builtins.open", new_callable=unittest.mock.mock_open)
    @patch("crewai.cli.shared.token_manager.os.chmod")
    def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path):
        mock_path = MagicMock()
        mock_get_path.return_value = mock_path
        filename = "test_file.txt"
        content = b"test_content"

        self.token_manager.save_secure_file(filename, content)

        mock_path.__truediv__.assert_called_once_with(filename)
        mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb")
        mock_open().write.assert_called_once_with(content)
        mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600)

    @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
    @patch(
        "builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content"
    )
    def test_read_secure_file_exists(self, mock_open, mock_get_path):
        mock_path = MagicMock()
        mock_get_path.return_value = mock_path
        mock_path.__truediv__.return_value.exists.return_value = True
        filename = "test_file.txt"

        result = self.token_manager.read_secure_file(filename)

        self.assertEqual(result, b"test_content")
        mock_path.__truediv__.assert_called_once_with(filename)
        mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb")

    @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
    def test_read_secure_file_not_exists(self, mock_get_path):
        mock_path = MagicMock()
        mock_get_path.return_value = mock_path
        mock_path.__truediv__.return_value.exists.return_value = False
        filename = "test_file.txt"

        result = self.token_manager.read_secure_file(filename)

        self.assertIsNone(result)
        mock_path.__truediv__.assert_called_once_with(filename)

    @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
    def test_clear_tokens(self, mock_get_path):
        mock_path = MagicMock()
        mock_get_path.return_value = mock_path

        self.token_manager.clear_tokens()

        mock_path.__truediv__.assert_called_once_with("tokens.enc")
        mock_path.__truediv__.return_value.unlink.assert_called_once_with(
            missing_ok=True
        )

@@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch
import pytest
from pytest import raises

from crewai.cli.shared.token_manager import TokenManager
from crewai.cli.authentication.utils import TokenManager
from crewai.cli.tools.main import ToolCommand

@@ -1,793 +0,0 @@
|
||||
"""Tests for QdrantClient implementation."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.types import BaseRecord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client():
|
||||
"""Create a mock Qdrant client."""
|
||||
return Mock(spec=SyncQdrantClient)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_qdrant_client():
|
||||
"""Create a mock async Qdrant client."""
|
||||
return Mock(spec=AsyncQdrantClient)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_qdrant_client) -> QdrantClient:
|
||||
"""Create a QdrantClient instance for testing."""
|
||||
mock_embedding = Mock()
|
||||
mock_embedding.return_value = [0.1, 0.2, 0.3]
|
||||
client = QdrantClient(client=mock_qdrant_client, embedding_function=mock_embedding)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client(mock_async_qdrant_client) -> QdrantClient:
|
||||
"""Create a QdrantClient instance with async client for testing."""
|
||||
mock_embedding = Mock()
|
||||
mock_embedding.return_value = [0.1, 0.2, 0.3]
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=mock_embedding
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
class TestQdrantClient:
|
||||
"""Test suite for QdrantClient."""
|
||||
|
||||
def test_create_collection(self, client, mock_qdrant_client):
|
||||
"""Test that create_collection creates a new collection."""
|
||||
mock_qdrant_client.collection_exists.return_value = False
|
||||
|
||||
client.create_collection(collection_name="test_collection")
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.create_collection.assert_called_once()
|
||||
call_args = mock_qdrant_client.create_collection.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["vectors_config"] is not None
|
||||
|
||||
def test_create_collection_already_exists(self, client, mock_qdrant_client):
|
||||
"""Test that create_collection raises error if collection exists."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' already exists"
|
||||
):
|
||||
client.create_collection(collection_name="test_collection")
|
||||
|
||||
def test_create_collection_wrong_client_type(self, mock_async_qdrant_client):
|
||||
"""Test that create_collection raises TypeError for async client."""
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=Mock()
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method create_collection\(\) requires"
|
||||
):
|
||||
client.create_collection(collection_name="test_collection")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acreate_collection(self, async_client, mock_async_qdrant_client):
|
||||
"""Test that acreate_collection creates a new collection asynchronously."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
|
||||
mock_async_qdrant_client.create_collection = AsyncMock()
|
||||
|
||||
await async_client.acreate_collection(collection_name="test_collection")
|
||||
|
||||
mock_async_qdrant_client.collection_exists.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
mock_async_qdrant_client.create_collection.assert_called_once()
|
||||
call_args = mock_async_qdrant_client.create_collection.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["vectors_config"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acreate_collection_already_exists(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test that acreate_collection raises error if collection exists."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' already exists"
|
||||
):
|
||||
await async_client.acreate_collection(collection_name="test_collection")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acreate_collection_wrong_client_type(self, mock_qdrant_client):
|
||||
"""Test that acreate_collection raises TypeError for sync client."""
|
||||
client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method acreate_collection\(\) requires"
|
||||
):
|
||||
await client.acreate_collection(collection_name="test_collection")
|
||||
|
||||
def test_get_or_create_collection_existing(self, client, mock_qdrant_client):
|
||||
"""Test get_or_create_collection returns existing collection."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
mock_collection_info = Mock()
|
||||
mock_qdrant_client.get_collection.return_value = mock_collection_info
|
||||
|
||||
result = client.get_or_create_collection(collection_name="test_collection")
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.create_collection.assert_not_called()
|
||||
assert result == mock_collection_info
|
||||
|
||||
def test_get_or_create_collection_new(self, client, mock_qdrant_client):
|
||||
"""Test get_or_create_collection creates new collection."""
|
||||
mock_qdrant_client.collection_exists.return_value = False
|
||||
mock_collection_info = Mock()
|
||||
mock_qdrant_client.get_collection.return_value = mock_collection_info
|
||||
|
||||
result = client.get_or_create_collection(collection_name="test_collection")
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.create_collection.assert_called_once()
|
||||
mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
|
||||
assert result == mock_collection_info
|
||||
|
||||
def test_get_or_create_collection_wrong_client_type(self, mock_async_qdrant_client):
|
||||
"""Test get_or_create_collection raises TypeError for async client."""
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=Mock()
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError,
|
||||
match=r"Method get_or_create_collection\(\) requires",
|
||||
):
|
||||
client.get_or_create_collection(collection_name="test_collection")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_or_create_collection_existing(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test aget_or_create_collection returns existing collection."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
mock_collection_info = Mock()
|
||||
mock_async_qdrant_client.get_collection = AsyncMock(
|
||||
return_value=mock_collection_info
|
||||
)
|
||||
|
||||
result = await async_client.aget_or_create_collection(
|
||||
collection_name="test_collection"
|
||||
)
|
||||
|
||||
mock_async_qdrant_client.collection_exists.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
mock_async_qdrant_client.get_collection.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
mock_async_qdrant_client.create_collection.assert_not_called()
|
||||
assert result == mock_collection_info
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_or_create_collection_new(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test aget_or_create_collection creates new collection."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
|
||||
mock_async_qdrant_client.create_collection = AsyncMock()
|
||||
mock_collection_info = Mock()
|
||||
mock_async_qdrant_client.get_collection = AsyncMock(
|
||||
return_value=mock_collection_info
|
||||
)
|
||||
|
||||
result = await async_client.aget_or_create_collection(
|
||||
collection_name="test_collection"
|
||||
)
|
||||
|
||||
mock_async_qdrant_client.collection_exists.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
mock_async_qdrant_client.create_collection.assert_called_once()
|
||||
mock_async_qdrant_client.get_collection.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
assert result == mock_collection_info
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_or_create_collection_wrong_client_type(
|
||||
self, mock_qdrant_client
|
||||
):
|
||||
"""Test aget_or_create_collection raises TypeError for sync client."""
|
||||
client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError,
|
||||
match=r"Method aget_or_create_collection\(\) requires",
|
||||
):
|
||||
await client.aget_or_create_collection(collection_name="test_collection")
|
||||
|
||||
def test_add_documents(self, client, mock_qdrant_client):
|
||||
"""Test that add_documents adds documents to collection."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
client.embedding_function.assert_called_once_with("Test document")
|
||||
mock_qdrant_client.upsert.assert_called_once()
|
||||
|
||||
# Check upsert was called with correct parameters
|
||||
call_args = mock_qdrant_client.upsert.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["wait"] is True
|
||||
assert len(call_args.kwargs["points"]) == 1
|
||||
point = call_args.kwargs["points"][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload["content"] == "Test document"
|
||||
assert point.payload["source"] == "test"
|
||||
|
||||
def test_add_documents_with_doc_id(self, client, mock_qdrant_client):
|
||||
"""Test that add_documents uses provided doc_id."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"doc_id": "custom-id-123",
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
call_args = mock_qdrant_client.upsert.call_args
|
||||
point = call_args.kwargs["points"][0]
|
||||
assert point.id == "custom-id-123"
|
||||
|
||||
def test_add_documents_empty_list(self, client, mock_qdrant_client):
|
||||
"""Test that add_documents raises error for empty documents list."""
|
||||
documents: list[BaseRecord] = []
|
||||
|
||||
with pytest.raises(ValueError, match="Documents list cannot be empty"):
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
def test_add_documents_collection_not_exists(self, client, mock_qdrant_client):
|
||||
"""Test that add_documents raises error if collection doesn't exist."""
|
||||
mock_qdrant_client.collection_exists.return_value = False
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' does not exist"
|
||||
):
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
def test_add_documents_wrong_client_type(self, mock_async_qdrant_client):
|
||||
"""Test that add_documents raises TypeError for async client."""
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=Mock()
|
||||
)
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method add_documents\(\) requires"
|
||||
):
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents(self, async_client, mock_async_qdrant_client):
|
||||
"""Test that aadd_documents adds documents to collection asynchronously."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
mock_async_qdrant_client.upsert = AsyncMock()
|
||||
async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
mock_async_qdrant_client.collection_exists.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
async_client.embedding_function.assert_called_once_with("Test document")
|
||||
mock_async_qdrant_client.upsert.assert_called_once()
|
||||
|
||||
# Check upsert was called with correct parameters
|
||||
call_args = mock_async_qdrant_client.upsert.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["wait"] is True
|
||||
assert len(call_args.kwargs["points"]) == 1
|
||||
point = call_args.kwargs["points"][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload["content"] == "Test document"
|
||||
assert point.payload["source"] == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_with_doc_id(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test that aadd_documents uses provided doc_id."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
mock_async_qdrant_client.upsert = AsyncMock()
|
||||
async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"doc_id": "custom-id-123",
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
call_args = mock_async_qdrant_client.upsert.call_args
|
||||
point = call_args.kwargs["points"][0]
|
||||
assert point.id == "custom-id-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_empty_list(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test that aadd_documents raises error for empty documents list."""
|
||||
documents: list[BaseRecord] = []
|
||||
|
||||
with pytest.raises(ValueError, match="Documents list cannot be empty"):
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_collection_not_exists(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test that aadd_documents raises error if collection doesn't exist."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' does not exist"
|
||||
):
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_wrong_client_type(self, mock_qdrant_client):
|
||||
"""Test that aadd_documents raises TypeError for sync client."""
|
||||
client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method aadd_documents\(\) requires"
|
||||
):
|
||||
await client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
def test_search(self, client, mock_qdrant_client):
|
||||
"""Test that search returns matching documents."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_point = Mock()
|
||||
mock_point.id = "doc-123"
|
||||
mock_point.payload = {"content": "Test content", "source": "test"}
|
||||
mock_point.score = 0.95
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.points = [mock_point]
|
||||
mock_qdrant_client.query_points.return_value = mock_response
|
||||
|
||||
results = client.search(collection_name="test_collection", query="test query")
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
client.embedding_function.assert_called_once_with("test query")
|
||||
mock_qdrant_client.query_points.assert_called_once()
|
||||
|
||||
call_args = mock_qdrant_client.query_points.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
|
||||
assert call_args.kwargs["limit"] == 10
|
||||
assert call_args.kwargs["with_payload"] is True
|
||||
assert call_args.kwargs["with_vectors"] is False
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == "doc-123"
|
||||
assert results[0]["content"] == "Test content"
|
||||
assert results[0]["metadata"] == {"source": "test"}
|
||||
assert results[0]["score"] == 0.975
|
||||
|
||||
def test_search_with_filters(self, client, mock_qdrant_client):
|
||||
"""Test that search applies metadata filters correctly."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.points = []
|
||||
mock_qdrant_client.query_points.return_value = mock_response
|
||||
|
||||
client.search(
|
||||
collection_name="test_collection",
|
||||
query="test query",
|
||||
metadata_filter={"category": "tech", "status": "published"},
|
||||
)
|
||||
|
||||
call_args = mock_qdrant_client.query_points.call_args
|
||||
query_filter = call_args.kwargs["query_filter"]
|
||||
assert len(query_filter.must) == 2
|
||||
assert any(
|
||||
cond.key == "category" and cond.match.value == "tech"
|
||||
for cond in query_filter.must
|
||||
)
|
||||
assert any(
|
||||
cond.key == "status" and cond.match.value == "published"
|
||||
for cond in query_filter.must
|
||||
)
|
||||
|
||||
def test_search_with_options(self, client, mock_qdrant_client):
|
||||
"""Test that search applies limit and score_threshold correctly."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.points = []
|
||||
mock_qdrant_client.query_points.return_value = mock_response
|
||||
|
||||
client.search(
|
||||
collection_name="test_collection",
|
||||
query="test query",
|
||||
limit=5,
|
||||
score_threshold=0.8,
|
||||
)
|
||||
|
||||
call_args = mock_qdrant_client.query_points.call_args
|
||||
assert call_args.kwargs["limit"] == 5
|
||||
assert call_args.kwargs["score_threshold"] == 0.8
|
||||
|
||||
def test_search_collection_not_exists(self, client, mock_qdrant_client):
|
||||
"""Test that search raises error if collection doesn't exist."""
|
||||
mock_qdrant_client.collection_exists.return_value = False
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' does not exist"
|
||||
):
|
||||
client.search(collection_name="test_collection", query="test query")
|
||||
|
||||
def test_search_wrong_client_type(self, mock_async_qdrant_client):
|
||||
"""Test that search raises TypeError for async client."""
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=Mock()
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method search\(\) requires"
|
||||
):
|
||||
client.search(collection_name="test_collection", query="test query")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, async_client, mock_async_qdrant_client):
|
||||
"""Test that asearch returns matching documents asynchronously."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_point = Mock()
|
||||
mock_point.id = "doc-123"
|
||||
mock_point.payload = {"content": "Test content", "source": "test"}
|
||||
mock_point.score = 0.95
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.points = [mock_point]
|
||||
mock_async_qdrant_client.query_points = AsyncMock(return_value=mock_response)
|
||||
|
||||
        results = await async_client.asearch(
            collection_name="test_collection", query="test query"
        )

        mock_async_qdrant_client.collection_exists.assert_called_once_with(
            "test_collection"
        )
        async_client.embedding_function.assert_called_once_with("test query")
        mock_async_qdrant_client.query_points.assert_called_once()

        call_args = mock_async_qdrant_client.query_points.call_args
        assert call_args.kwargs["collection_name"] == "test_collection"
        assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
        assert call_args.kwargs["limit"] == 10
        assert call_args.kwargs["with_payload"] is True
        assert call_args.kwargs["with_vectors"] is False

        assert len(results) == 1
        assert results[0]["id"] == "doc-123"
        assert results[0]["content"] == "Test content"
        assert results[0]["metadata"] == {"source": "test"}
        assert results[0]["score"] == 0.975

    @pytest.mark.asyncio
    async def test_asearch_with_filters(self, async_client, mock_async_qdrant_client):
        """Test that asearch applies metadata filters correctly."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
        async_client.embedding_function.return_value = [0.1, 0.2, 0.3]

        mock_response = Mock()
        mock_response.points = []
        mock_async_qdrant_client.query_points = AsyncMock(return_value=mock_response)

        await async_client.asearch(
            collection_name="test_collection",
            query="test query",
            metadata_filter={"category": "tech", "status": "published"},
        )

        call_args = mock_async_qdrant_client.query_points.call_args
        query_filter = call_args.kwargs["query_filter"]
        assert len(query_filter.must) == 2
        assert any(
            cond.key == "category" and cond.match.value == "tech"
            for cond in query_filter.must
        )
        assert any(
            cond.key == "status" and cond.match.value == "published"
            for cond in query_filter.must
        )
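The filter assertions above imply that `asearch` AND-combines each metadata key into a `must` clause. A minimal sketch of that translation, using plain dataclasses as hypothetical stand-ins for qdrant_client's `Filter`/`FieldCondition`/`MatchValue` models (the real client builds the actual qdrant models):

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the qdrant_client filter models.
@dataclass
class MatchValue:
    value: object


@dataclass
class FieldCondition:
    key: str
    match: MatchValue


@dataclass
class Filter:
    must: list


def build_metadata_filter(metadata_filter: dict) -> Filter:
    """Turn a flat metadata dict into an AND-combined Qdrant-style filter."""
    return Filter(
        must=[
            FieldCondition(key=key, match=MatchValue(value=value))
            for key, value in metadata_filter.items()
        ]
    )


query_filter = build_metadata_filter({"category": "tech", "status": "published"})
```

Each dict entry becomes one `must` condition, which is exactly what the two `assert any(...)` checks above look for.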

    @pytest.mark.asyncio
    async def test_asearch_collection_not_exists(
        self, async_client, mock_async_qdrant_client
    ):
        """Test that asearch raises error if collection doesn't exist."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)

        with pytest.raises(
            ValueError, match="Collection 'test_collection' does not exist"
        ):
            await async_client.asearch(
                collection_name="test_collection", query="test query"
            )

    @pytest.mark.asyncio
    async def test_asearch_wrong_client_type(self, mock_qdrant_client):
        """Test that asearch raises TypeError for sync client."""
        client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())

        with pytest.raises(
            ClientMethodMismatchError, match=r"Method asearch\(\) requires"
        ):
            await client.asearch(collection_name="test_collection", query="test query")

    def test_delete_collection(self, client, mock_qdrant_client):
        """Test that delete_collection deletes the collection."""
        mock_qdrant_client.collection_exists.return_value = True

        client.delete_collection(collection_name="test_collection")

        mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
        mock_qdrant_client.delete_collection.assert_called_once_with(
            collection_name="test_collection"
        )

    def test_delete_collection_not_exists(self, client, mock_qdrant_client):
        """Test that delete_collection raises error if collection doesn't exist."""
        mock_qdrant_client.collection_exists.return_value = False

        with pytest.raises(
            ValueError, match="Collection 'test_collection' does not exist"
        ):
            client.delete_collection(collection_name="test_collection")

        mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
        mock_qdrant_client.delete_collection.assert_not_called()

    def test_delete_collection_wrong_client_type(self, mock_async_qdrant_client):
        """Test that delete_collection raises TypeError for async client."""
        client = QdrantClient(
            client=mock_async_qdrant_client, embedding_function=Mock()
        )

        with pytest.raises(
            ClientMethodMismatchError, match=r"Method delete_collection\(\) requires"
        ):
            client.delete_collection(collection_name="test_collection")

    @pytest.mark.asyncio
    async def test_adelete_collection(self, async_client, mock_async_qdrant_client):
        """Test that adelete_collection deletes the collection asynchronously."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
        mock_async_qdrant_client.delete_collection = AsyncMock()

        await async_client.adelete_collection(collection_name="test_collection")

        mock_async_qdrant_client.collection_exists.assert_called_once_with(
            "test_collection"
        )
        mock_async_qdrant_client.delete_collection.assert_called_once_with(
            collection_name="test_collection"
        )

    @pytest.mark.asyncio
    async def test_adelete_collection_not_exists(
        self, async_client, mock_async_qdrant_client
    ):
        """Test that adelete_collection raises error if collection doesn't exist."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)

        with pytest.raises(
            ValueError, match="Collection 'test_collection' does not exist"
        ):
            await async_client.adelete_collection(collection_name="test_collection")

        mock_async_qdrant_client.collection_exists.assert_called_once_with(
            "test_collection"
        )
        mock_async_qdrant_client.delete_collection.assert_not_called()

    @pytest.mark.asyncio
    async def test_adelete_collection_wrong_client_type(self, mock_qdrant_client):
        """Test that adelete_collection raises TypeError for sync client."""
        client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())

        with pytest.raises(
            ClientMethodMismatchError, match=r"Method adelete_collection\(\) requires"
        ):
            await client.adelete_collection(collection_name="test_collection")

    def test_reset(self, client, mock_qdrant_client):
        """Test that reset deletes all collections."""
        mock_collection1 = Mock()
        mock_collection1.name = "collection1"
        mock_collection2 = Mock()
        mock_collection2.name = "collection2"
        mock_collection3 = Mock()
        mock_collection3.name = "collection3"

        mock_collections_response = Mock()
        mock_collections_response.collections = [
            mock_collection1,
            mock_collection2,
            mock_collection3,
        ]
        mock_qdrant_client.get_collections.return_value = mock_collections_response

        client.reset()

        mock_qdrant_client.get_collections.assert_called_once()
        assert mock_qdrant_client.delete_collection.call_count == 3
        mock_qdrant_client.delete_collection.assert_any_call(
            collection_name="collection1"
        )
        mock_qdrant_client.delete_collection.assert_any_call(
            collection_name="collection2"
        )
        mock_qdrant_client.delete_collection.assert_any_call(
            collection_name="collection3"
        )

    def test_reset_no_collections(self, client, mock_qdrant_client):
        """Test that reset handles no collections gracefully."""
        mock_collections_response = Mock()
        mock_collections_response.collections = []
        mock_qdrant_client.get_collections.return_value = mock_collections_response

        client.reset()

        mock_qdrant_client.get_collections.assert_called_once()
        mock_qdrant_client.delete_collection.assert_not_called()

    def test_reset_wrong_client_type(self, mock_async_qdrant_client):
        """Test that reset raises TypeError for async client."""
        client = QdrantClient(
            client=mock_async_qdrant_client, embedding_function=Mock()
        )

        with pytest.raises(
            ClientMethodMismatchError, match=r"Method reset\(\) requires"
        ):
            client.reset()

    @pytest.mark.asyncio
    async def test_areset(self, async_client, mock_async_qdrant_client):
        """Test that areset deletes all collections asynchronously."""
        mock_collection1 = Mock()
        mock_collection1.name = "collection1"
        mock_collection2 = Mock()
        mock_collection2.name = "collection2"
        mock_collection3 = Mock()
        mock_collection3.name = "collection3"

        mock_collections_response = Mock()
        mock_collections_response.collections = [
            mock_collection1,
            mock_collection2,
            mock_collection3,
        ]
        mock_async_qdrant_client.get_collections = AsyncMock(
            return_value=mock_collections_response
        )
        mock_async_qdrant_client.delete_collection = AsyncMock()

        await async_client.areset()

        mock_async_qdrant_client.get_collections.assert_called_once()
        assert mock_async_qdrant_client.delete_collection.call_count == 3
        mock_async_qdrant_client.delete_collection.assert_any_call(
            collection_name="collection1"
        )
        mock_async_qdrant_client.delete_collection.assert_any_call(
            collection_name="collection2"
        )
        mock_async_qdrant_client.delete_collection.assert_any_call(
            collection_name="collection3"
        )

    @pytest.mark.asyncio
    async def test_areset_no_collections(self, async_client, mock_async_qdrant_client):
        """Test that areset handles no collections gracefully."""
        mock_collections_response = Mock()
        mock_collections_response.collections = []
        mock_async_qdrant_client.get_collections = AsyncMock(
            return_value=mock_collections_response
        )

        await async_client.areset()

        mock_async_qdrant_client.get_collections.assert_called_once()
        mock_async_qdrant_client.delete_collection.assert_not_called()

    @pytest.mark.asyncio
    async def test_areset_wrong_client_type(self, mock_qdrant_client):
        """Test that areset raises TypeError for sync client."""
        client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())

        with pytest.raises(
            ClientMethodMismatchError, match=r"Method areset\(\) requires"
        ):
            await client.areset()
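The wrong-client-type tests above all exercise the same guard: each public method checks that the wrapped backend matches its sync/async flavor before doing any work, and raises `ClientMethodMismatchError` (a `TypeError` subclass, per the docstrings) otherwise. A hedged sketch of how such a guard could look — the class and backend names here are hypothetical, not crewAI's actual implementation:

```python
class ClientMethodMismatchError(TypeError):
    """Raised when a sync method is called on an async backend, or vice versa."""

    def __init__(self, method: str, required: str) -> None:
        super().__init__(f"Method {method}() requires a {required} client")


class SyncBackend:
    """Stand-in for a synchronous qdrant_client.QdrantClient."""


class AsyncBackend:
    """Stand-in for an AsyncQdrantClient."""


class QdrantClientSketch:
    def __init__(self, client) -> None:
        self._client = client

    def _require_sync(self, method: str) -> None:
        if isinstance(self._client, AsyncBackend):
            raise ClientMethodMismatchError(method, "synchronous")

    def _require_async(self, method: str) -> None:
        if isinstance(self._client, SyncBackend):
            raise ClientMethodMismatchError(method, "asynchronous")

    def reset(self) -> None:
        self._require_sync("reset")
        # ...delete every collection via the sync backend...

    async def areset(self) -> None:
        self._require_async("areset")
        # ...delete every collection via the async backend...
```

Calling `QdrantClientSketch(AsyncBackend()).reset()` raises before any backend call is made, which is why the tests can assert the mismatch without mocking any collection operations.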

@@ -1,39 +0,0 @@
"""Test the embeddings factory functionality, particularly ONNX provider."""

import pytest


def test_onnx_embedding_function_creation():
    """Test that ONNX embedding function can be created."""
    from crewai.rag.embeddings.factory import get_embedding_function

    embedding_func = get_embedding_function({"provider": "onnx"})
    assert embedding_func is not None


def test_onnx_embedding_function_basic_functionality():
    """Test that ONNX embedding function can process text."""
    import numpy as np
    from crewai.rag.embeddings.factory import get_embedding_function

    embedding_func = get_embedding_function({"provider": "onnx"})

    result = embedding_func(["test text"])
    assert result is not None
    assert len(result) > 0
    assert isinstance(result[0], np.ndarray)
    assert len(result[0]) > 0


def test_get_embedding_function_onnx_provider_in_list():
    """Test that onnx provider is available in the factory."""
    from crewai.rag.embeddings.factory import get_embedding_function

    try:
        embedding_func = get_embedding_function({"provider": "onnx"})
        assert embedding_func is not None
    except ValueError as e:
        if "Unsupported provider" in str(e):
            pytest.fail("ONNX provider should be supported")
        else:
            raise
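The deleted factory tests above hinge on a provider-keyed dispatch that raises "Unsupported provider" for unknown keys. A stdlib-only sketch of that shape — the stub embedder and registry here are invented for illustration; the real `crewai.rag.embeddings.factory` wires up actual embedding backends:

```python
def _stub_onnx_embedder(texts):
    # Stand-in: a real ONNX embedder returns one fixed-size vector per input text.
    return [[0.0, 0.0, 0.0] for _ in texts]


# Hypothetical provider registry; each entry maps a provider name to a builder
# that turns the config dict into a callable embedding function.
_PROVIDERS = {"onnx": lambda config: _stub_onnx_embedder}


def get_embedding_function_sketch(config: dict):
    """Dispatch on config["provider"], mirroring the error the tests expect."""
    provider = config.get("provider")
    if provider not in _PROVIDERS:
        raise ValueError(f"Unsupported provider: {provider}")
    return _PROVIDERS[provider](config)


embed = get_embedding_function_sketch({"provider": "onnx"})
vectors = embed(["test text"])
```

The try/except in the last test above works because an unknown provider raises a `ValueError` containing "Unsupported provider", while any other `ValueError` is re-raised.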

@@ -13,12 +13,3 @@ def test_crew_output_import():
    from crewai import CrewOutput

    assert CrewOutput is not None


def test_onnxruntime_import_and_version():
    """Test that onnxruntime can be imported and is version >= 1.22.1."""
    import onnxruntime
    from packaging import version

    assert onnxruntime is not None
    assert version.parse(onnxruntime.__version__) >= version.parse("1.22.1")
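The version gate above uses `packaging.version` because plain string comparison orders versions wrongly ("1.9.0" sorts after "1.22.1"). For simple dotted numeric versions, a stdlib-only tuple comparison gives the same ordering as `packaging` (this helper is a sketch; it does not handle pre-release suffixes like "1.22.1rc1"):

```python
def version_tuple(v: str) -> tuple:
    """Parse a dotted numeric version like "1.22.1" into a comparable tuple."""
    return tuple(int(part) for part in v.split("."))


# String comparison gets this wrong; tuple comparison does not.
assert "1.9.0" > "1.22.1"
assert version_tuple("1.9.0") < version_tuple("1.22.1")
```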

@@ -1,42 +0,0 @@
"""Tests for import utilities."""

import pytest
from unittest.mock import patch

from crewai.utilities.import_utils import require, OptionalDependencyError


class TestRequire:
    """Test the require function."""

    def test_require_existing_module(self):
        """Test requiring a module that exists."""
        module = require("json", purpose="testing")
        assert module.__name__ == "json"

    def test_require_missing_module(self):
        """Test requiring a module that doesn't exist."""
        with pytest.raises(OptionalDependencyError) as exc_info:
            require("nonexistent_module_xyz", purpose="testing missing module")

        error_msg = str(exc_info.value)
        assert (
            "testing missing module requires the optional dependency 'nonexistent_module_xyz'"
            in error_msg
        )
        assert "uv add nonexistent_module_xyz" in error_msg

    def test_require_with_import_error(self):
        """Test that ImportError is properly chained."""
        with patch("importlib.import_module") as mock_import:
            mock_import.side_effect = ImportError("Module import failed")

            with pytest.raises(OptionalDependencyError) as exc_info:
                require("some_module", purpose="testing error handling")

        assert isinstance(exc_info.value.__cause__, ImportError)
        assert str(exc_info.value.__cause__) == "Module import failed"

    def test_optional_dependency_error_is_import_error(self):
        """Test that OptionalDependencyError is a subclass of ImportError."""
        assert issubclass(OptionalDependencyError, ImportError)
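The message format and exception chaining these tests assert suggest a thin wrapper around `importlib.import_module`. A sketch consistent with the asserted behavior — the real `crewai.utilities.import_utils.require` may differ in detail:

```python
import importlib


class OptionalDependencyError(ImportError):
    """Raised when an optional dependency needed for a feature is missing."""


def require(name: str, *, purpose: str):
    """Import `name`, or raise an error that names the purpose and install command."""
    try:
        return importlib.import_module(name)
    except ImportError as e:
        # `raise ... from e` preserves the original ImportError as __cause__,
        # which is what test_require_with_import_error checks.
        raise OptionalDependencyError(
            f"{purpose} requires the optional dependency '{name}'. "
            f"Add it with: uv add {name}"
        ) from e


json_module = require("json", purpose="testing")
```

Subclassing `ImportError` (the last test above) keeps existing `except ImportError` handlers working while still letting callers catch the more specific error.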