mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-20 13:28:13 +00:00

Compare commits: devin/1768...devin/1768 (3 commits)

| Author | SHA1 | Date |
|---|---|---|
| | fae812ffb7 | |
| | 17e3fcbe1f | |
| | b858d705a8 | |
@@ -574,6 +574,10 @@ When you run this Flow, the output will change based on the random boolean value

### Human in the Loop (human feedback)

<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**.
</Note>

The `@human_feedback` decorator enables human-in-the-loop workflows by pausing flow execution to collect feedback from a human. This is useful for approval gates, quality review, and decision points that require human judgment.

```python Code
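The pause-and-collect behavior described above can be sketched without CrewAI at all. The following is a minimal stand-in, not CrewAI's implementation: `human_feedback`, `ask`, and `draft_summary` are illustrative names, and the real decorator additionally integrates with Flow state and listeners.

```python
# Minimal stand-in for the pause-and-collect pattern behind a
# human-feedback decorator. NOT CrewAI's implementation; the names
# `human_feedback`, `ask`, and `draft_summary` are illustrative.
from typing import Callable


def human_feedback(ask: Callable[[str], str]):
    """Wrap a step so its output is shown to a human before continuing."""
    def decorator(step: Callable[..., str]):
        def wrapper(*args, **kwargs):
            output = step(*args, **kwargs)
            feedback = ask(output)  # execution "pauses" here for a human
            return {"output": output, "feedback": feedback}
        return wrapper
    return decorator


# Simulated reviewer that always approves.
def approve_all(output: str) -> str:
    return "approved"


@human_feedback(ask=approve_all)
def draft_summary() -> str:
    return "Q3 revenue grew 12%."


result = draft_summary()
print(result["feedback"])  # approved
```

In a real flow, `ask` would block on console input or an external review channel instead of returning immediately.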
@@ -7,6 +7,10 @@ mode: "wide"

## Overview

<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**. Make sure to update your installation before using this feature.
</Note>

The `@human_feedback` decorator enables human-in-the-loop (HITL) workflows directly within CrewAI Flows. It allows you to pause flow execution, present output to a human for review, collect their feedback, and optionally route to different listeners based on the feedback outcome.

This is particularly valuable for:
@@ -11,10 +11,10 @@ Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelli

CrewAI offers two main approaches for implementing human-in-the-loop workflows:

| Approach | Best For | Integration |
|----------|----------|-------------|
| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/en/learn/human-feedback-in-flows) |
| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide |
| Approach | Best For | Integration | Version |
|----------|----------|-------------|---------|
| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/en/learn/human-feedback-in-flows) | **1.8.0+** |
| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide | - |

<Tip>
If you're building flows and want to add human review steps with routing based on feedback, check out the [Human Feedback in Flows](/en/learn/human-feedback-in-flows) guide for the `@human_feedback` decorator.
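For the webhook-based approach, the routing decision boils down to inspecting the feedback payload when it arrives. The sketch below is stdlib-only and the payload shape (`execution_id`, `approved`, `feedback`) is an assumption for illustration, not the Enterprise webhook contract.

```python
# Sketch of routing on human feedback received via webhook.
# The payload keys ("execution_id", "approved", "feedback") are
# assumed for illustration; consult the actual webhook contract.
def route_feedback(payload: dict) -> str:
    """Map a feedback payload to a next action for the paused workflow."""
    if payload.get("approved"):
        return "resume"                     # continue the workflow
    # Rejected: retry with guidance if the reviewer left a comment.
    return "retry" if payload.get("feedback") else "abort"


print(route_feedback({"execution_id": "abc", "approved": True}))    # resume
print(route_feedback({"execution_id": "abc", "approved": False,
                      "feedback": "tighten the intro"}))            # retry
```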
@@ -567,6 +567,10 @@ Fourth method running

### Human in the Loop (human feedback)

<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**.
</Note>

The `@human_feedback` decorator enables human-in-the-loop workflows that pause flow execution to collect feedback from a human. This is useful for approval gates, quality review, and decision points that require human judgment.

```python Code
@@ -7,6 +7,10 @@ mode: "wide"

## Overview

<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**. Update your installation before using this feature.
</Note>

The `@human_feedback` decorator enables human-in-the-loop (HITL) workflows directly within CrewAI Flows. It lets you pause flow execution, present output to a human for review, collect their feedback, and optionally route to different listeners based on the feedback outcome.

This is particularly useful for:
@@ -5,9 +5,22 @@ icon: "user-check"
mode: "wide"
---

Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. This guide walks you through implementing HITL within CrewAI.
Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. CrewAI provides several ways to implement HITL depending on your needs.

## Setting Up HITL Workflows
## Choosing Your HITL Approach

CrewAI offers two main approaches for implementing human-in-the-loop workflows:

| Approach | Best For | Integration | Version |
|----------|----------|-------------|---------|
| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/ko/learn/human-feedback-in-flows) | **1.8.0+** |
| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide | - |

<Tip>
If you're building flows and want to add human review steps that route based on feedback, see the [Human Feedback in Flows](/ko/learn/human-feedback-in-flows) guide for the `@human_feedback` decorator.
</Tip>

## Setting Up Webhook-Based HITL Workflows

<Steps>
<Step title="Configure Your Task">
@@ -309,6 +309,10 @@ When you run this Flow, the output will differ depending on the random boolean value

### Human in the Loop (human feedback)

<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**.
</Note>

The `@human_feedback` decorator enables human-in-the-loop workflows by pausing flow execution to collect feedback from a human. This is useful for approval gates, quality review, and decision points that require human judgment.

```python Code
@@ -7,6 +7,10 @@ mode: "wide"

## Overview

<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**. Make sure to update your installation before using this feature.
</Note>

The `@human_feedback` decorator enables human-in-the-loop (HITL) workflows directly within CrewAI Flows. It allows you to pause flow execution, present output to a human for review, collect their feedback, and optionally route to different listeners based on the feedback outcome.

This is particularly valuable for:
@@ -5,9 +5,22 @@ icon: "user-check"
mode: "wide"
---

Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. This guide shows how to implement HITL within CrewAI.
Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. CrewAI offers several ways to implement HITL depending on your needs.

## Setting Up HITL Workflows
## Choosing Your HITL Approach

CrewAI offers two main approaches for implementing human-in-the-loop workflows:

| Approach | Best For | Integration | Version |
|----------|----------|-------------|---------|
| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/pt-BR/learn/human-feedback-in-flows) | **1.8.0+** |
| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide | - |

<Tip>
If you're building flows and want to add human review steps with routing based on feedback, check out the [Human Feedback in Flows](/pt-BR/learn/human-feedback-in-flows) guide for the `@human_feedback` decorator.
</Tip>

## Setting Up Webhook-Based HITL Workflows

<Steps>
<Step title="Configure Your Task">
@@ -209,10 +209,9 @@ class EventListener(BaseEventListener):
        @crewai_event_bus.on(TaskCompletedEvent)
        def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
            # Handle telemetry
            span = self.execution_spans.get(source)
            span = self.execution_spans.pop(source, None)
            if span:
                self._telemetry.task_ended(span, source, source.agent.crew)
            self.execution_spans[source] = None

            # Pass task name if it exists
            task_name = get_task_name(source)
@@ -222,11 +221,10 @@ class EventListener(BaseEventListener):

        @crewai_event_bus.on(TaskFailedEvent)
        def on_task_failed(source: Any, event: TaskFailedEvent) -> None:
            span = self.execution_spans.get(source)
            span = self.execution_spans.pop(source, None)
            if span:
                if source.agent and source.agent.crew:
                    self._telemetry.task_ended(span, source, source.agent.crew)
            self.execution_spans[source] = None

            # Pass task name if it exists
            task_name = get_task_name(source)
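The change in these two hunks replaces `get()` followed by a `None` re-assignment with a single `pop(source, None)`. The difference matters for memory: `get()` plus re-assignment leaves a stale `None` entry behind for every task, so the spans dict grows without bound, while `pop()` removes the key outright. A toy demonstration:

```python
# Why pop() beats get()+reassign for one-shot lookups:
# get() leaves a stale None entry behind; pop() removes the key.
spans = {"task-1": "span-1"}

# Old pattern: fetch, then overwrite with None.
span = spans.get("task-1")
spans["task-1"] = None
assert "task-1" in spans          # stale entry remains, dict keeps growing

# New pattern: fetch and delete in one step.
spans = {"task-1": "span-1"}
span = spans.pop("task-1", None)  # returns None instead of raising if absent
assert span == "span-1"
assert "task-1" not in spans      # no stale entry left behind
```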
@@ -1,9 +1,12 @@
from crewai.tools.base_tool import BaseTool, EnvVar, tool
from crewai.tools.tool_search_tool import SearchStrategy, ToolSearchTool


__all__ = [
    "BaseTool",
    "EnvVar",
    "SearchStrategy",
    "ToolSearchTool",
    "tool",
]
333 lib/crewai/src/crewai/tools/tool_search_tool.py Normal file
@@ -0,0 +1,333 @@
"""Tool Search Tool for on-demand tool discovery.

This module implements a Tool Search Tool that allows agents to dynamically
discover and load tools on-demand, reducing token consumption when working
with large tool libraries.

Inspired by Anthropic's Tool Search Tool approach for on-demand tool loading.
"""

from __future__ import annotations

from collections.abc import Callable, Sequence
from enum import Enum
import json
import re
from typing import Any

from pydantic import BaseModel, Field

from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities.pydantic_schema_utils import generate_model_description


class SearchStrategy(str, Enum):
    """Search strategy for tool discovery."""

    KEYWORD = "keyword"
    REGEX = "regex"


class ToolSearchResult(BaseModel):
    """Result from a tool search operation."""

    name: str = Field(description="The name of the tool")
    description: str = Field(description="The description of the tool")
    args_schema: dict[str, Any] = Field(
        description="The JSON schema for the tool's arguments"
    )


class ToolSearchToolSchema(BaseModel):
    """Schema for the Tool Search Tool arguments."""

    query: str = Field(
        description="The search query to find relevant tools. Use keywords that describe the capability you need."
    )
    max_results: int = Field(
        default=5,
        description="Maximum number of tools to return. Default is 5.",
        ge=1,
        le=20,
    )


class ToolSearchTool(BaseTool):
    """A tool that searches through a catalog of tools to find relevant ones.

    This tool enables on-demand tool discovery, allowing agents to work with
    large tool libraries without loading all tool definitions upfront. Instead
    of consuming tokens with all tool definitions, the agent can search for
    relevant tools when needed.

    Example:
        ```python
        from crewai.tools import BaseTool, ToolSearchTool

        # Create your tools
        search_tool = MySearchTool()
        scrape_tool = MyScrapeWebsiteTool()
        database_tool = MyDatabaseTool()

        # Create a tool search tool with your tool catalog
        tool_search = ToolSearchTool(
            tool_catalog=[search_tool, scrape_tool, database_tool],
            search_strategy=SearchStrategy.KEYWORD,
        )

        # Use with an agent - only the tool_search is loaded initially
        agent = Agent(
            role="Researcher",
            tools=[tool_search],  # Other tools discovered on-demand
        )
        ```

    Attributes:
        tool_catalog: List of tools available for search.
        search_strategy: Strategy to use for searching (keyword or regex).
        custom_search_fn: Optional custom search function for advanced matching.
    """

    name: str = Field(
        default="Tool Search",
        description="The name of the tool search tool.",
    )
    description: str = Field(
        default="Search for available tools by describing the capability you need. Returns tool definitions that match your query.",
        description="Description of what the tool search tool does.",
    )
    args_schema: type[BaseModel] = Field(
        default=ToolSearchToolSchema,
        description="The schema for the tool search arguments.",
    )
    tool_catalog: list[BaseTool | CrewStructuredTool] = Field(
        default_factory=list,
        description="List of tools available for search.",
    )
    search_strategy: SearchStrategy = Field(
        default=SearchStrategy.KEYWORD,
        description="Strategy to use for searching tools.",
    )
    custom_search_fn: Callable[
        [str, Sequence[BaseTool | CrewStructuredTool]], list[BaseTool | CrewStructuredTool]
    ] | None = Field(
        default=None,
        description="Optional custom search function for advanced matching.",
    )

    def _run(self, query: str, max_results: int = 5) -> str:
        """Search for tools matching the query.

        Args:
            query: The search query to find relevant tools.
            max_results: Maximum number of tools to return.

        Returns:
            JSON string containing the matching tool definitions.
        """
        if not self.tool_catalog:
            return json.dumps(
                {
                    "status": "error",
                    "message": "No tools available in the catalog.",
                    "tools": [],
                }
            )

        if self.custom_search_fn:
            matching_tools = self.custom_search_fn(query, self.tool_catalog)
        elif self.search_strategy == SearchStrategy.REGEX:
            matching_tools = self._regex_search(query)
        else:
            matching_tools = self._keyword_search(query)

        matching_tools = matching_tools[:max_results]

        if not matching_tools:
            return json.dumps(
                {
                    "status": "no_results",
                    "message": f"No tools found matching query: '{query}'. Try different keywords.",
                    "tools": [],
                }
            )

        tool_results = []
        for tool in matching_tools:
            tool_info = self._get_tool_info(tool)
            tool_results.append(tool_info)

        return json.dumps(
            {
                "status": "success",
                "message": f"Found {len(tool_results)} tool(s) matching your query.",
                "tools": tool_results,
            },
            indent=2,
        )

    def _keyword_search(
        self, query: str
    ) -> list[BaseTool | CrewStructuredTool]:
        """Search tools using keyword matching.

        Args:
            query: The search query.

        Returns:
            List of matching tools sorted by relevance.
        """
        query_lower = query.lower()
        query_words = set(query_lower.split())

        scored_tools: list[tuple[float, BaseTool | CrewStructuredTool]] = []

        for tool in self.tool_catalog:
            score = self._calculate_keyword_score(tool, query_lower, query_words)
            if score > 0:
                scored_tools.append((score, tool))

        scored_tools.sort(key=lambda x: x[0], reverse=True)
        return [tool for _, tool in scored_tools]

    def _calculate_keyword_score(
        self,
        tool: BaseTool | CrewStructuredTool,
        query_lower: str,
        query_words: set[str],
    ) -> float:
        """Calculate relevance score for a tool based on keyword matching.

        Args:
            tool: The tool to score.
            query_lower: Lowercase query string.
            query_words: Set of query words.

        Returns:
            Relevance score (higher is better).
        """
        score = 0.0
        tool_name_lower = tool.name.lower()
        tool_desc_lower = tool.description.lower()

        if query_lower in tool_name_lower:
            score += 10.0
        if query_lower in tool_desc_lower:
            score += 5.0

        for word in query_words:
            if len(word) < 2:
                continue
            if word in tool_name_lower:
                score += 3.0
            if word in tool_desc_lower:
                score += 1.0

        return score

    def _regex_search(
        self, query: str
    ) -> list[BaseTool | CrewStructuredTool]:
        """Search tools using regex pattern matching.

        Args:
            query: The regex pattern to search for.

        Returns:
            List of matching tools.
        """
        try:
            pattern = re.compile(query, re.IGNORECASE)
        except re.error:
            pattern = re.compile(re.escape(query), re.IGNORECASE)

        return [
            tool
            for tool in self.tool_catalog
            if pattern.search(tool.name) or pattern.search(tool.description)
        ]

    def _get_tool_info(self, tool: BaseTool | CrewStructuredTool) -> dict[str, Any]:
        """Get tool information as a dictionary.

        Args:
            tool: The tool to get information from.

        Returns:
            Dictionary containing tool name, description, and args schema.
        """
        if isinstance(tool, BaseTool):
            schema_dict = generate_model_description(tool.args_schema)
            args_schema = schema_dict.get("json_schema", {}).get("schema", {})
        else:
            args_schema = tool.args_schema.model_json_schema()

        return {
            "name": tool.name,
            "description": self._get_original_description(tool),
            "args_schema": args_schema,
        }

    def _get_original_description(self, tool: BaseTool | CrewStructuredTool) -> str:
        """Get the original description of a tool without the generated schema.

        Args:
            tool: The tool to get the description from.

        Returns:
            The original tool description.
        """
        description = tool.description
        if "Tool Description:" in description:
            parts = description.split("Tool Description:")
            if len(parts) > 1:
                return parts[1].strip()
        return description

    def add_tool(self, tool: BaseTool | CrewStructuredTool) -> None:
        """Add a tool to the catalog.

        Args:
            tool: The tool to add.
        """
        self.tool_catalog.append(tool)

    def add_tools(self, tools: Sequence[BaseTool | CrewStructuredTool]) -> None:
        """Add multiple tools to the catalog.

        Args:
            tools: The tools to add.
        """
        self.tool_catalog.extend(tools)

    def remove_tool(self, tool_name: str) -> bool:
        """Remove a tool from the catalog by name.

        Args:
            tool_name: The name of the tool to remove.

        Returns:
            True if the tool was removed, False if not found.
        """
        for i, tool in enumerate(self.tool_catalog):
            if tool.name == tool_name:
                self.tool_catalog.pop(i)
                return True
        return False

    def get_catalog_size(self) -> int:
        """Get the number of tools in the catalog.

        Returns:
            The number of tools in the catalog.
        """
        return len(self.tool_catalog)

    def list_tool_names(self) -> list[str]:
        """List all tool names in the catalog.

        Returns:
            List of tool names.
        """
        return [tool.name for tool in self.tool_catalog]
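The ranking in `_calculate_keyword_score` can be exercised standalone. The function below re-implements the same weights outside the class for illustration: a whole-query hit on the name or description scores 10 or 5, each individual word hit scores 3 or 1, and words shorter than 2 characters are ignored.

```python
# Standalone replica of ToolSearchTool's keyword scoring, extracted
# from the class for illustration (same weights: 10/5 whole-query,
# 3/1 per word, words under 2 chars skipped).
def keyword_score(name: str, description: str, query: str) -> float:
    q = query.lower()
    name_l, desc_l = name.lower(), description.lower()
    score = 0.0
    if q in name_l:
        score += 10.0
    if q in desc_l:
        score += 5.0
    for word in set(q.split()):
        if len(word) < 2:
            continue
        if word in name_l:
            score += 3.0
        if word in desc_l:
            score += 1.0
    return score


# Whole query matches the name (+10); "web" and "search" each hit the
# name (+3 each) and the description (+1 each): 10 + 3 + 1 + 3 + 1 = 18.
s = keyword_score("Web Search", "Search the web for information.", "web search")
print(s)  # 18.0
```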
@@ -2,11 +2,8 @@ from datetime import datetime
import json
import os
import pickle
import tempfile
import threading
from typing import Any, TypedDict

import portalocker
from typing_extensions import Unpack

@@ -126,15 +123,10 @@ class FileHandler:


class PickleHandler:
    """Thread-safe handler for saving and loading data using pickle.

    This class provides thread-safe file operations using portalocker for
    cross-process file locking and atomic write operations to prevent
    data corruption during concurrent access.
    """Handler for saving and loading data using pickle.

    Attributes:
        file_path: The path to the pickle file.
        _lock: Threading lock for thread-safe operations within the same process.
    """

    def __init__(self, file_name: str) -> None:
@@ -149,62 +141,34 @@ class PickleHandler:
            file_name += ".pkl"

        self.file_path = os.path.join(os.getcwd(), file_name)
        self._lock = threading.Lock()

    def initialize_file(self) -> None:
        """Initialize the file with an empty dictionary and overwrite any existing data."""
        self.save({})

    def save(self, data: Any) -> None:
        """Save the data to the specified file using pickle with thread-safe atomic writes.

        This method uses a two-phase approach for thread safety:
        1. Threading lock for same-process thread safety
        2. Atomic write (write to temp file, then rename) for cross-process safety
           and data integrity
        """
        Save the data to the specified file using pickle.

        Args:
            data: The data to be saved to the file.
            data: The data to be saved to the file.
        """
        with self._lock:
            dir_name = os.path.dirname(self.file_path) or os.getcwd()
            fd, temp_path = tempfile.mkstemp(
                suffix=".pkl.tmp", prefix="pickle_", dir=dir_name
            )
            try:
                with os.fdopen(fd, "wb") as f:
                    pickle.dump(obj=data, file=f)
                    f.flush()
                    os.fsync(f.fileno())
                os.replace(temp_path, self.file_path)
            except Exception:
                if os.path.exists(temp_path):
                    os.unlink(temp_path)
                raise
        with open(self.file_path, "wb") as f:
            pickle.dump(obj=data, file=f)

    def load(self) -> Any:
        """Load the data from the specified file using pickle with thread-safe locking.

        This method uses portalocker for cross-process read locking to ensure
        data consistency when multiple processes may be accessing the file.
        """Load the data from the specified file using pickle.

        Returns:
            The data loaded from the file, or an empty dictionary if the file
            does not exist or is empty.
            The data loaded from the file.
        """
        with self._lock:
            if (
                not os.path.exists(self.file_path)
                or os.path.getsize(self.file_path) == 0
            ):
                return {}
        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
            return {}  # Return an empty dictionary if the file does not exist or is empty

            with portalocker.Lock(
                self.file_path, "rb", flags=portalocker.LOCK_SH
            ) as file:
                try:
                    return pickle.load(file)  # noqa: S301
                except EOFError:
                    return {}
                except Exception:
                    raise
        with open(self.file_path, "rb") as file:
            try:
                return pickle.load(file)  # noqa: S301
            except EOFError:
                return {}  # Return an empty dictionary if the file is empty or corrupted
            except Exception:
                raise  # Raise any other exceptions that occur during loading
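One side of the `save()` hunk above uses a write-to-temp-then-rename pattern. Stripped to its essentials, the idea is that the destination file is never observed half-written, because `os.replace` swaps the file into place atomically on the same filesystem. This is a stdlib-only sketch; `atomic_write_bytes` is an illustrative name, not part of crewAI.

```python
# Sketch of the write-to-temp-then-rename pattern: readers either see
# the old file or the complete new one, never a partial write.
import os
import tempfile


def atomic_write_bytes(path: str, data: bytes) -> None:
    # The temp file must live on the same filesystem as the target,
    # otherwise os.replace is no longer a single atomic rename.
    dir_name = os.path.dirname(path) or "."
    fd, tmp = tempfile.mkstemp(dir=dir_name)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())   # force bytes to disk before the swap
        os.replace(tmp, path)      # atomic swap into place
    except Exception:
        if os.path.exists(tmp):
            os.unlink(tmp)         # clean up the orphaned temp file
        raise


with tempfile.TemporaryDirectory() as d:
    target = os.path.join(d, "state.bin")
    atomic_write_bytes(target, b"v1")
    atomic_write_bytes(target, b"v2")  # overwriting is also atomic
    with open(target, "rb") as f:
        print(f.read())  # b'v2'
```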
393 lib/crewai/tests/tools/test_tool_search_tool.py Normal file
@@ -0,0 +1,393 @@
|
||||
"""Tests for the ToolSearchTool functionality."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.tools import BaseTool, SearchStrategy, ToolSearchTool
|
||||
|
||||
|
||||
class MockSearchTool(BaseTool):
|
||||
"""A mock search tool for testing."""
|
||||
|
||||
name: str = "Web Search"
|
||||
description: str = "Search the web for information on any topic."
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Search results for: {query}"
|
||||
|
||||
|
||||
class MockDatabaseTool(BaseTool):
|
||||
"""A mock database tool for testing."""
|
||||
|
||||
name: str = "Database Query"
|
||||
description: str = "Query a SQL database to retrieve data."
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Database results for: {query}"
|
||||
|
||||
|
||||
class MockScrapeTool(BaseTool):
|
||||
"""A mock web scraping tool for testing."""
|
||||
|
||||
name: str = "Web Scraper"
|
||||
description: str = "Scrape content from websites and extract text."
|
||||
|
||||
def _run(self, url: str) -> str:
|
||||
return f"Scraped content from: {url}"
|
||||
|
||||
|
||||
class MockEmailTool(BaseTool):
|
||||
"""A mock email tool for testing."""
|
||||
|
||||
name: str = "Send Email"
|
||||
description: str = "Send an email to a specified recipient."
|
||||
|
||||
def _run(self, to: str, subject: str, body: str) -> str:
|
||||
return f"Email sent to {to}"
|
||||
|
||||
|
||||
class MockCalculatorTool(BaseTool):
|
||||
"""A mock calculator tool for testing."""
|
||||
|
||||
name: str = "Calculator"
|
||||
description: str = "Perform mathematical calculations and arithmetic operations."
|
||||
|
||||
def _run(self, expression: str) -> str:
|
||||
return f"Result: {eval(expression)}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools() -> list[BaseTool]:
|
||||
"""Create a list of sample tools for testing."""
|
||||
return [
|
||||
MockSearchTool(),
|
||||
MockDatabaseTool(),
|
||||
MockScrapeTool(),
|
||||
MockEmailTool(),
|
||||
MockCalculatorTool(),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_search(sample_tools: list[BaseTool]) -> ToolSearchTool:
|
||||
"""Create a ToolSearchTool with sample tools."""
|
||||
return ToolSearchTool(tool_catalog=sample_tools)
|
||||
|
||||
|
||||
class TestToolSearchToolCreation:
|
||||
"""Tests for ToolSearchTool creation and initialization."""
|
||||
|
||||
def test_create_tool_search_with_empty_catalog(self) -> None:
|
||||
"""Test creating a ToolSearchTool with an empty catalog."""
|
||||
tool_search = ToolSearchTool()
|
||||
assert tool_search.name == "Tool Search"
|
||||
assert tool_search.tool_catalog == []
|
||||
assert tool_search.search_strategy == SearchStrategy.KEYWORD
|
||||
|
||||
def test_create_tool_search_with_tools(self, sample_tools: list[BaseTool]) -> None:
|
||||
"""Test creating a ToolSearchTool with a list of tools."""
|
||||
tool_search = ToolSearchTool(tool_catalog=sample_tools)
|
||||
assert len(tool_search.tool_catalog) == 5
|
||||
assert tool_search.get_catalog_size() == 5
|
||||
|
||||
def test_create_tool_search_with_regex_strategy(
|
||||
self, sample_tools: list[BaseTool]
|
||||
) -> None:
|
||||
"""Test creating a ToolSearchTool with regex search strategy."""
|
||||
tool_search = ToolSearchTool(
|
||||
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
||||
)
|
||||
assert tool_search.search_strategy == SearchStrategy.REGEX
|
||||
|
||||
def test_create_tool_search_with_custom_name(self) -> None:
|
||||
"""Test creating a ToolSearchTool with a custom name."""
|
||||
tool_search = ToolSearchTool(name="My Tool Finder")
|
||||
assert tool_search.name == "My Tool Finder"
|
||||
|
||||
|
||||
class TestToolSearchKeywordSearch:
|
||||
"""Tests for keyword-based tool search."""
|
||||
|
||||
def test_search_by_exact_name(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test searching for a tool by its exact name."""
|
||||
result = tool_search._run("Web Search")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert len(result_data["tools"]) >= 1
|
||||
assert result_data["tools"][0]["name"] == "Web Search"
|
||||
|
||||
def test_search_by_partial_name(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test searching for a tool by partial name."""
|
||||
result = tool_search._run("Search")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert len(result_data["tools"]) >= 1
|
||||
tool_names = [t["name"] for t in result_data["tools"]]
|
||||
assert "Web Search" in tool_names
|
||||
|
||||
def test_search_by_description_keyword(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test searching for a tool by keyword in description."""
|
||||
result = tool_search._run("database")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert len(result_data["tools"]) >= 1
|
||||
tool_names = [t["name"] for t in result_data["tools"]]
|
||||
assert "Database Query" in tool_names
|
||||
|
||||
def test_search_with_multiple_keywords(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test searching with multiple keywords."""
|
||||
result = tool_search._run("web scrape content")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert len(result_data["tools"]) >= 1
|
||||
tool_names = [t["name"] for t in result_data["tools"]]
|
||||
assert "Web Scraper" in tool_names
|
||||
|
||||
def test_search_no_results(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test searching with a query that returns no results."""
|
||||
result = tool_search._run("xyznonexistent123abc")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "no_results"
|
||||
assert len(result_data["tools"]) == 0
|
||||
|
||||
def test_search_max_results_limit(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test that max_results limits the number of returned tools."""
|
||||
result = tool_search._run("tool", max_results=2)
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert len(result_data["tools"]) <= 2
|
||||
|
||||
def test_search_empty_catalog(self) -> None:
|
||||
"""Test searching with an empty tool catalog."""
|
||||
tool_search = ToolSearchTool()
|
||||
result = tool_search._run("search")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "error"
|
||||
assert "No tools available" in result_data["message"]
|
||||
|
||||
|
||||
class TestToolSearchRegexSearch:
|
||||
"""Tests for regex-based tool search."""
|
||||
|
||||
def test_regex_search_simple_pattern(
|
||||
self, sample_tools: list[BaseTool]
|
||||
) -> None:
|
||||
"""Test regex search with a simple pattern."""
|
||||
tool_search = ToolSearchTool(
|
||||
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
||||
)
|
||||
result = tool_search._run("Web.*")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
tool_names = [t["name"] for t in result_data["tools"]]
|
||||
assert "Web Search" in tool_names or "Web Scraper" in tool_names
|
||||
|
||||
def test_regex_search_case_insensitive(
|
||||
self, sample_tools: list[BaseTool]
|
||||
) -> None:
|
||||
"""Test that regex search is case insensitive."""
|
||||
tool_search = ToolSearchTool(
|
||||
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
||||
)
|
||||
result = tool_search._run("email")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
tool_names = [t["name"] for t in result_data["tools"]]
|
||||
assert "Send Email" in tool_names
|
||||
|
||||
def test_regex_search_invalid_pattern_fallback(
|
||||
self, sample_tools: list[BaseTool]
|
||||
) -> None:
|
||||
"""Test that invalid regex patterns are escaped and still work."""
|
||||
tool_search = ToolSearchTool(
|
||||
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
||||
)
|
||||
result = tool_search._run("[invalid(regex")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] in ["success", "no_results"]
|
||||
|
||||
|
||||
class TestToolSearchCustomSearch:
    """Tests for custom search function."""

    def test_custom_search_function(self, sample_tools: list[BaseTool]) -> None:
        """Test using a custom search function."""

        def custom_search(
            query: str, tools: list[BaseTool]
        ) -> list[BaseTool]:
            return [t for t in tools if "email" in t.name.lower()]

        tool_search = ToolSearchTool(
            tool_catalog=sample_tools, custom_search_fn=custom_search
        )
        result = tool_search._run("anything")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        assert len(result_data["tools"]) == 1
        assert result_data["tools"][0]["name"] == "Send Email"


class TestToolSearchCatalogManagement:
    """Tests for tool catalog management."""

    def test_add_tool(self, tool_search: ToolSearchTool) -> None:
        """Test adding a tool to the catalog."""
        initial_size = tool_search.get_catalog_size()

        class NewTool(BaseTool):
            name: str = "New Tool"
            description: str = "A new tool for testing."

            def _run(self) -> str:
                return "New tool result"

        tool_search.add_tool(NewTool())
        assert tool_search.get_catalog_size() == initial_size + 1

    def test_add_tools(self, tool_search: ToolSearchTool) -> None:
        """Test adding multiple tools to the catalog."""
        initial_size = tool_search.get_catalog_size()

        class NewTool1(BaseTool):
            name: str = "New Tool 1"
            description: str = "First new tool."

            def _run(self) -> str:
                return "Result 1"

        class NewTool2(BaseTool):
            name: str = "New Tool 2"
            description: str = "Second new tool."

            def _run(self) -> str:
                return "Result 2"

        tool_search.add_tools([NewTool1(), NewTool2()])
        assert tool_search.get_catalog_size() == initial_size + 2

    def test_remove_tool(self, tool_search: ToolSearchTool) -> None:
        """Test removing a tool from the catalog."""
        initial_size = tool_search.get_catalog_size()
        result = tool_search.remove_tool("Web Search")

        assert result is True
        assert tool_search.get_catalog_size() == initial_size - 1

    def test_remove_nonexistent_tool(self, tool_search: ToolSearchTool) -> None:
        """Test removing a tool that doesn't exist."""
        initial_size = tool_search.get_catalog_size()
        result = tool_search.remove_tool("Nonexistent Tool")

        assert result is False
        assert tool_search.get_catalog_size() == initial_size

    def test_list_tool_names(self, tool_search: ToolSearchTool) -> None:
        """Test listing all tool names in the catalog."""
        names = tool_search.list_tool_names()

        assert len(names) == 5
        assert "Web Search" in names
        assert "Database Query" in names
        assert "Web Scraper" in names
        assert "Send Email" in names
        assert "Calculator" in names


class TestToolSearchResultFormat:
    """Tests for the format of search results."""

    def test_result_contains_tool_info(self, tool_search: ToolSearchTool) -> None:
        """Test that search results contain complete tool information."""
        result = tool_search._run("Calculator")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        tool_info = result_data["tools"][0]

        assert "name" in tool_info
        assert "description" in tool_info
        assert "args_schema" in tool_info
        assert tool_info["name"] == "Calculator"

    def test_result_args_schema_format(self, tool_search: ToolSearchTool) -> None:
        """Test that args_schema is properly formatted."""
        result = tool_search._run("Email")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        tool_info = result_data["tools"][0]

        assert "args_schema" in tool_info
        args_schema = tool_info["args_schema"]
        assert isinstance(args_schema, dict)


class TestToolSearchIntegration:
    """Integration tests for ToolSearchTool."""

    def test_tool_search_as_base_tool(self, sample_tools: list[BaseTool]) -> None:
        """Test that ToolSearchTool works as a BaseTool."""
        tool_search = ToolSearchTool(tool_catalog=sample_tools)

        assert isinstance(tool_search, BaseTool)
        assert tool_search.name == "Tool Search"
        assert "search" in tool_search.description.lower()

    def test_tool_search_to_structured_tool(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Test converting ToolSearchTool to structured tool."""
        tool_search = ToolSearchTool(tool_catalog=sample_tools)
        structured = tool_search.to_structured_tool()

        assert structured.name == "Tool Search"
        assert structured.args_schema is not None

    def test_tool_search_run_method(self, tool_search: ToolSearchTool) -> None:
        """Test the run method of ToolSearchTool."""
        result = tool_search.run(query="search", max_results=3)

        assert isinstance(result, str)
        result_data = json.loads(result)
        assert "status" in result_data
        assert "tools" in result_data


class TestToolSearchScoring:
    """Tests for the keyword scoring algorithm."""

    def test_exact_name_match_scores_highest(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Test that exact name matches score higher than partial matches."""
        tool_search = ToolSearchTool(tool_catalog=sample_tools)
        result = tool_search._run("Web Search")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        assert result_data["tools"][0]["name"] == "Web Search"

    def test_name_match_scores_higher_than_description(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Test that name matches score higher than description matches."""
        tool_search = ToolSearchTool(tool_catalog=sample_tools)
        result = tool_search._run("Calculator")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        assert result_data["tools"][0]["name"] == "Calculator"
@@ -1,8 +1,6 @@
import os
import threading
import unittest
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from crewai.utilities.file_handler import PickleHandler
@@ -10,6 +8,7 @@ from crewai.utilities.file_handler import PickleHandler

class TestPickleHandler(unittest.TestCase):
    def setUp(self):
        # Use a unique file name for each test to avoid race conditions in parallel test execution
        unique_id = str(uuid.uuid4())
        self.file_name = f"test_data_{unique_id}.pkl"
        self.file_path = os.path.join(os.getcwd(), self.file_name)
@@ -48,234 +47,3 @@ class TestPickleHandler(unittest.TestCase):

        assert str(exc.value) == "pickle data was truncated"
        assert "<class '_pickle.UnpicklingError'>" == str(exc.type)


class TestPickleHandlerThreadSafety(unittest.TestCase):
    """Tests for thread-safety of PickleHandler operations."""

    def setUp(self):
        unique_id = str(uuid.uuid4())
        self.file_name = f"test_thread_safe_{unique_id}.pkl"
        self.file_path = os.path.join(os.getcwd(), self.file_name)
        self.handler = PickleHandler(self.file_name)

    def tearDown(self):
        if os.path.exists(self.file_path):
            os.remove(self.file_path)

    def test_concurrent_writes_same_handler(self):
        """Test that concurrent writes from multiple threads using the same handler don't corrupt data."""
        num_threads = 10
        num_writes_per_thread = 20
        errors: list[Exception] = []
        write_count = 0
        count_lock = threading.Lock()

        def write_data(thread_id: int) -> None:
            nonlocal write_count
            for i in range(num_writes_per_thread):
                try:
                    data = {"thread": thread_id, "iteration": i, "data": f"value_{thread_id}_{i}"}
                    self.handler.save(data)
                    with count_lock:
                        write_count += 1
                except Exception as e:
                    errors.append(e)

        threads = []
        for i in range(num_threads):
            t = threading.Thread(target=write_data, args=(i,))
            threads.append(t)
            t.start()

        for t in threads:
            t.join()

        assert len(errors) == 0, f"Errors occurred during concurrent writes: {errors}"
        assert write_count == num_threads * num_writes_per_thread
        loaded_data = self.handler.load()
        assert isinstance(loaded_data, dict)
        assert "thread" in loaded_data
        assert "iteration" in loaded_data

    def test_concurrent_reads_same_handler(self):
        """Test that concurrent reads from multiple threads don't cause issues."""
        test_data = {"key": "value", "nested": {"a": 1, "b": 2}}
        self.handler.save(test_data)

        num_threads = 20
        results: list[dict] = []
        errors: list[Exception] = []
        results_lock = threading.Lock()

        def read_data() -> None:
            try:
                data = self.handler.load()
                with results_lock:
                    results.append(data)
            except Exception as e:
                errors.append(e)

        threads = []
        for _ in range(num_threads):
            t = threading.Thread(target=read_data)
            threads.append(t)
            t.start()

        for t in threads:
            t.join()

        assert len(errors) == 0, f"Errors occurred during concurrent reads: {errors}"
        assert len(results) == num_threads
        for result in results:
            assert result == test_data

    def test_concurrent_read_write_same_handler(self):
        """Test that concurrent reads and writes don't corrupt data or cause errors."""
        initial_data = {"counter": 0}
        self.handler.save(initial_data)

        num_writers = 5
        num_readers = 10
        writes_per_thread = 10
        reads_per_thread = 20
        write_errors: list[Exception] = []
        read_errors: list[Exception] = []
        read_results: list[dict] = []
        results_lock = threading.Lock()

        def writer(thread_id: int) -> None:
            for i in range(writes_per_thread):
                try:
                    data = {"writer": thread_id, "write_num": i}
                    self.handler.save(data)
                except Exception as e:
                    write_errors.append(e)

        def reader() -> None:
            for _ in range(reads_per_thread):
                try:
                    data = self.handler.load()
                    with results_lock:
                        read_results.append(data)
                except Exception as e:
                    read_errors.append(e)

        threads = []
        for i in range(num_writers):
            t = threading.Thread(target=writer, args=(i,))
            threads.append(t)

        for _ in range(num_readers):
            t = threading.Thread(target=reader)
            threads.append(t)

        for t in threads:
            t.start()

        for t in threads:
            t.join()

        assert len(write_errors) == 0, f"Write errors: {write_errors}"
        assert len(read_errors) == 0, f"Read errors: {read_errors}"
        for result in read_results:
            assert isinstance(result, dict)

    def test_atomic_write_no_partial_data(self):
        """Test that atomic writes prevent partial/corrupted data from being read."""
        large_data = {"key": "x" * 100000, "numbers": list(range(10000))}
        num_iterations = 50
        errors: list[Exception] = []
        corruption_detected = False
        corruption_lock = threading.Lock()

        def writer() -> None:
            for _ in range(num_iterations):
                try:
                    self.handler.save(large_data)
                except Exception as e:
                    errors.append(e)

        def reader() -> None:
            nonlocal corruption_detected
            for _ in range(num_iterations * 2):
                try:
                    data = self.handler.load()
                    if data and data != {} and data != large_data:
                        with corruption_lock:
                            corruption_detected = True
                except Exception as e:
                    errors.append(e)

        writer_thread = threading.Thread(target=writer)
        reader_thread = threading.Thread(target=reader)

        writer_thread.start()
        reader_thread.start()

        writer_thread.join()
        reader_thread.join()

        assert len(errors) == 0, f"Errors occurred: {errors}"
        assert not corruption_detected, "Partial/corrupted data was read"

    def test_thread_pool_concurrent_operations(self):
        """Test thread safety using ThreadPoolExecutor for more realistic concurrent access."""
        num_operations = 100
        errors: list[Exception] = []

        def operation(op_id: int) -> str:
            try:
                if op_id % 3 == 0:
                    self.handler.save({"op_id": op_id, "type": "write"})
                    return f"write_{op_id}"
                else:
                    data = self.handler.load()
                    return f"read_{op_id}_{type(data).__name__}"
            except Exception as e:
                errors.append(e)
                return f"error_{op_id}"

        with ThreadPoolExecutor(max_workers=20) as executor:
            futures = [executor.submit(operation, i) for i in range(num_operations)]
            results = [f.result() for f in as_completed(futures)]

        assert len(errors) == 0, f"Errors occurred: {errors}"
        assert len(results) == num_operations

    def test_multiple_handlers_same_file(self):
        """Test that multiple PickleHandler instances for the same file work correctly."""
        handler1 = PickleHandler(self.file_name)
        handler2 = PickleHandler(self.file_name)

        num_operations = 50
        errors: list[Exception] = []

        def use_handler1() -> None:
            for i in range(num_operations):
                try:
                    handler1.save({"handler": 1, "iteration": i})
                except Exception as e:
                    errors.append(e)

        def use_handler2() -> None:
            for i in range(num_operations):
                try:
                    handler2.save({"handler": 2, "iteration": i})
                except Exception as e:
                    errors.append(e)

        t1 = threading.Thread(target=use_handler1)
        t2 = threading.Thread(target=use_handler2)

        t1.start()
        t2.start()

        t1.join()
        t2.join()

        assert len(errors) == 0, f"Errors occurred: {errors}"
        final_data = self.handler.load()
        assert isinstance(final_data, dict)
        assert "handler" in final_data
        assert "iteration" in final_data