Compare commits

..

3 Commits

Author SHA1 Message Date
Devin AI
5db807b57c Fix import sorting with ruff
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-21 12:39:17 +00:00
Devin AI
710d20a66e Fix import sorting in test file
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-21 12:38:26 +00:00
Devin AI
245399bca0 Fix issue #2434: Allow tools to specify if they permit repeated usage
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-21 12:35:52 +00:00
18 changed files with 827 additions and 387 deletions

View File

@@ -59,7 +59,7 @@ There are three ways to configure LLMs in CrewAI. Choose the method that best fi
goal: Conduct comprehensive research and analysis
backstory: A dedicated research professional with years of experience
verbose: true
llm: openai/gpt-4o-mini # your model here
llm: openai/gpt-4o-mini # your model here
# (see provider configuration examples below for more)
```
@@ -111,7 +111,7 @@ There are three ways to configure LLMs in CrewAI. Choose the method that best fi
## Provider Configuration Examples
CrewAI supports a multitude of LLM providers, each offering unique features, authentication methods, and model capabilities.
CrewAI supports a multitude of LLM providers, each offering unique features, authentication methods, and model capabilities.
In this section, you'll find detailed examples that help you select, configure, and optimize the LLM that best fits your project's needs.
<AccordionGroup>
@@ -121,7 +121,7 @@ In this section, you'll find detailed examples that help you select, configure,
```toml Code
# Required
OPENAI_API_KEY=sk-...
# Optional
OPENAI_API_BASE=<custom-base-url>
OPENAI_ORGANIZATION=<your-org-id>
@@ -226,7 +226,7 @@ In this section, you'll find detailed examples that help you select, configure,
AZURE_API_KEY=<your-api-key>
AZURE_API_BASE=<your-resource-url>
AZURE_API_VERSION=<api-version>
# Optional
AZURE_AD_TOKEN=<your-azure-ad-token>
AZURE_API_TYPE=<your-azure-api-type>
@@ -289,7 +289,7 @@ In this section, you'll find detailed examples that help you select, configure,
| Mistral 8x7B Instruct | Up to 32k tokens | An MOE LLM that follows instructions, completes requests, and generates creative text. |
</Accordion>
<Accordion title="Amazon SageMaker">
```toml Code
AWS_ACCESS_KEY_ID=<your-access-key>
@@ -474,7 +474,7 @@ In this section, you'll find detailed examples that help you select, configure,
WATSONX_URL=<your-url>
WATSONX_APIKEY=<your-apikey>
WATSONX_PROJECT_ID=<your-project-id>
# Optional
WATSONX_TOKEN=<your-token>
WATSONX_DEPLOYMENT_SPACE_ID=<your-space-id>
@@ -491,7 +491,7 @@ In this section, you'll find detailed examples that help you select, configure,
<Accordion title="Ollama (Local LLMs)">
1. Install Ollama: [ollama.ai](https://ollama.ai/)
2. Run a model: `ollama run llama3`
2. Run a model: `ollama run llama2`
3. Configure:
```python Code
@@ -600,7 +600,7 @@ In this section, you'll find detailed examples that help you select, configure,
```toml Code
OPENROUTER_API_KEY=<your-api-key>
```
Example usage in your CrewAI project:
```python Code
llm = LLM(
@@ -723,7 +723,7 @@ Learn how to get the most out of your LLM configuration:
- Small tasks (up to 4K tokens): Standard models
- Medium tasks (between 4K-32K): Enhanced models
- Large tasks (over 32K): Large context models
```python
# Configure model with appropriate settings
llm = LLM(
@@ -760,11 +760,11 @@ Learn how to get the most out of your LLM configuration:
<Warning>
Most authentication issues can be resolved by checking API key format and environment variable names.
</Warning>
```bash
# OpenAI
OPENAI_API_KEY=sk-...
# Anthropic
ANTHROPIC_API_KEY=sk-ant-...
```
@@ -773,11 +773,11 @@ Learn how to get the most out of your LLM configuration:
<Check>
Always include the provider prefix in model names
</Check>
```python
# Correct
llm = LLM(model="openai/gpt-4")
# Incorrect
llm = LLM(model="gpt-4")
```
@@ -786,10 +786,5 @@ Learn how to get the most out of your LLM configuration:
<Tip>
Use larger context models for extensive tasks
</Tip>
```python
# Large context model
llm = LLM(model="openai/gpt-4o") # 128K tokens
```
</Tab>
</Tabs>

View File

@@ -300,7 +300,7 @@ email_summarizer:
```
<Tip>
Note how we use the same name for the task in the `tasks.yaml` (`email_summarizer_task`) file as the method name in the `crew.py` (`email_summarizer_task`) file.
Note how we use the same name for the agent in the `tasks.yaml` (`email_summarizer_task`) file as the method name in the `crew.py` (`email_summarizer_task`) file.
</Tip>
```yaml tasks.yaml

View File

@@ -17,9 +17,9 @@ dependencies = [
"pdfplumber>=0.11.4",
"regex>=2024.9.11",
# Telemetry and Monitoring
"opentelemetry-api>=1.30.0",
"opentelemetry-sdk>=1.30.0",
"opentelemetry-exporter-otlp-proto-http>=1.30.0",
"opentelemetry-api>=1.22.0",
"opentelemetry-sdk>=1.22.0",
"opentelemetry-exporter-otlp-proto-http>=1.22.0",
# Data Handling
"chromadb>=0.5.23",
"openpyxl>=3.1.5",

View File

@@ -136,7 +136,7 @@ class CrewAgentParser:
def _clean_action(self, text: str) -> str:
"""Clean action string by removing non-essential formatting characters."""
return text.strip().strip("*").strip()
return re.sub(r"^\s*\*+\s*|\s*\*+\s*$", "", text).strip()
def _safe_repair_json(self, tool_input: str) -> str:
UNABLE_TO_REPAIR_JSON_RESULTS = ['""', "{}"]

View File

@@ -1,5 +1,4 @@
import subprocess
from functools import lru_cache
class Repository:
@@ -36,7 +35,6 @@ class Repository:
encoding="utf-8",
).strip()
@lru_cache(maxsize=None)
def is_git_repo(self) -> bool:
"""Check if the current directory is a git repository."""
try:

View File

@@ -8,45 +8,45 @@ from pydantic import BaseModel
class FlowPersistence(abc.ABC):
"""Abstract base class for flow state persistence.
This class defines the interface that all persistence implementations must follow.
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
"""
@abc.abstractmethod
def init_db(self) -> None:
"""Initialize the persistence backend.
This method should handle any necessary setup, such as:
- Creating tables
- Establishing connections
- Setting up indexes
"""
pass
@abc.abstractmethod
def save_state(
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel],
state_data: Union[Dict[str, Any], BaseModel]
) -> None:
"""Persist the flow state after method completion.
Args:
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
pass
@abc.abstractmethod
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID.
Args:
flow_uuid: Unique identifier for the flow instance
Returns:
The most recent state as a dictionary, or None if no state exists
"""

View File

@@ -11,7 +11,6 @@ from typing import Any, Dict, Optional, Union
from pydantic import BaseModel
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.state_utils import to_serializable
class SQLiteFlowPersistence(FlowPersistence):
@@ -79,53 +78,34 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
Raises:
ValueError: If state_data is neither a dict nor a BaseModel
RuntimeError: If database operations fail
TypeError: If JSON serialization fails
"""
try:
# Convert state_data to a JSON-serializable dict using the helper method
state_dict = to_serializable(state_data)
# Convert state_data to dict, handling both Pydantic and dict cases
if isinstance(state_data, BaseModel):
state_dict = dict(state_data) # Use dict() for better type compatibility
elif isinstance(state_data, dict):
state_dict = state_data
else:
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
# Try to serialize to JSON to catch any serialization issues early
try:
state_json = json.dumps(state_dict)
except (TypeError, ValueError, OverflowError) as json_err:
raise TypeError(
f"Failed to serialize state to JSON: {json_err}"
) from json_err
# Perform database operation with error handling
try:
with sqlite3.connect(self.db_path) as conn:
conn.execute(
"""
INSERT INTO flow_states (
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
""",
(
flow_uuid,
method_name,
datetime.now(timezone.utc).isoformat(),
state_json,
),
)
except sqlite3.Error as db_err:
raise RuntimeError(f"Database operation failed: {db_err}") from db_err
except Exception as e:
# Log the error but don't crash the application
import logging
logging.error(f"Failed to save flow state: {e}")
# Re-raise to allow caller to handle or ignore
raise
with sqlite3.connect(self.db_path) as conn:
conn.execute(
"""
INSERT INTO flow_states (
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
""",
(
flow_uuid,
method_name,
datetime.now(timezone.utc).isoformat(),
json.dumps(state_dict),
),
)
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID.

View File

@@ -1,16 +1,36 @@
import json
from datetime import date, datetime
from enum import Enum
from typing import Any, Dict, List, Union
from pydantic import BaseModel
from crewai.flow import Flow
SerializablePrimitive = Union[str, int, float, bool, None]
Serializable = Union[
SerializablePrimitive, List["Serializable"], Dict[str, "Serializable"]
]
def export_state(flow: Flow) -> dict[str, Serializable]:
"""Exports the Flow's internal state as JSON-compatible data structures.
Performs a one-way transformation of a Flow's state into basic Python types
that can be safely serialized to JSON. To prevent infinite recursion with
circular references, the conversion is limited to a depth of 5 levels.
Args:
flow: The Flow object whose state needs to be exported
Returns:
dict[str, Any]: The transformed state using JSON-compatible Python
types.
"""
result = to_serializable(flow._state)
assert isinstance(result, dict)
return result
def to_serializable(
obj: Any, max_depth: int = 5, _current_depth: int = 0
) -> Serializable:
@@ -32,8 +52,6 @@ def to_serializable(
if isinstance(obj, (str, int, float, bool, type(None))):
return obj
elif isinstance(obj, Enum):
return obj.value
elif isinstance(obj, (date, datetime)):
return obj.isoformat()
elif isinstance(obj, (list, tuple, set)):

View File

@@ -114,60 +114,6 @@ LLM_CONTEXT_WINDOW_SIZES = {
"Llama-3.2-11B-Vision-Instruct": 16384,
"Meta-Llama-3.2-3B-Instruct": 4096,
"Meta-Llama-3.2-1B-Instruct": 16384,
# bedrock
"us.amazon.nova-pro-v1:0": 300000,
"us.amazon.nova-micro-v1:0": 128000,
"us.amazon.nova-lite-v1:0": 300000,
"us.anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
"us.anthropic.claude-3-5-haiku-20241022-v1:0": 200000,
"us.anthropic.claude-3-5-sonnet-20241022-v2:0": 200000,
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": 200000,
"us.anthropic.claude-3-sonnet-20240229-v1:0": 200000,
"us.anthropic.claude-3-opus-20240229-v1:0": 200000,
"us.anthropic.claude-3-haiku-20240307-v1:0": 200000,
"us.meta.llama3-2-11b-instruct-v1:0": 128000,
"us.meta.llama3-2-3b-instruct-v1:0": 131000,
"us.meta.llama3-2-90b-instruct-v1:0": 128000,
"us.meta.llama3-2-1b-instruct-v1:0": 131000,
"us.meta.llama3-1-8b-instruct-v1:0": 128000,
"us.meta.llama3-1-70b-instruct-v1:0": 128000,
"us.meta.llama3-3-70b-instruct-v1:0": 128000,
"us.meta.llama3-1-405b-instruct-v1:0": 128000,
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
"eu.anthropic.claude-3-sonnet-20240229-v1:0": 200000,
"eu.anthropic.claude-3-haiku-20240307-v1:0": 200000,
"eu.meta.llama3-2-3b-instruct-v1:0": 131000,
"eu.meta.llama3-2-1b-instruct-v1:0": 131000,
"apac.anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
"apac.anthropic.claude-3-5-sonnet-20241022-v2:0": 200000,
"apac.anthropic.claude-3-sonnet-20240229-v1:0": 200000,
"apac.anthropic.claude-3-haiku-20240307-v1:0": 200000,
"amazon.nova-pro-v1:0": 300000,
"amazon.nova-micro-v1:0": 128000,
"amazon.nova-lite-v1:0": 300000,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
"anthropic.claude-3-5-haiku-20241022-v1:0": 200000,
"anthropic.claude-3-5-sonnet-20241022-v2:0": 200000,
"anthropic.claude-3-7-sonnet-20250219-v1:0": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
"anthropic.claude-3-opus-20240229-v1:0": 200000,
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-v2:1": 200000,
"anthropic.claude-v2": 100000,
"anthropic.claude-instant-v1": 100000,
"meta.llama3-1-405b-instruct-v1:0": 128000,
"meta.llama3-1-70b-instruct-v1:0": 128000,
"meta.llama3-1-8b-instruct-v1:0": 128000,
"meta.llama3-70b-instruct-v1:0": 8000,
"meta.llama3-8b-instruct-v1:0": 8000,
"amazon.titan-text-lite-v1": 4000,
"amazon.titan-text-express-v1": 8000,
"cohere.command-text-v14": 4000,
"ai21.j2-mid-v1": 8191,
"ai21.j2-ultra-v1": 8191,
"ai21.jamba-instruct-v1:0": 256000,
"mistral.mistral-7b-instruct-v0:2": 32000,
"mistral.mixtral-8x7b-instruct-v0:1": 32000,
# mistral
"mistral-tiny": 32768,
"mistral-small-latest": 32768,

View File

@@ -1,7 +1,7 @@
import os
from typing import Any, Dict, List
from mem0 import Memory, MemoryClient
from mem0 import MemoryClient
from crewai.memory.storage.interface import Storage
@@ -32,16 +32,13 @@ class Mem0Storage(Storage):
mem0_org_id = config.get("org_id")
mem0_project_id = config.get("project_id")
# Initialize MemoryClient or Memory based on the presence of the mem0_api_key
if mem0_api_key:
if mem0_org_id and mem0_project_id:
self.memory = MemoryClient(
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
)
else:
self.memory = MemoryClient(api_key=mem0_api_key)
# Initialize MemoryClient with available parameters
if mem0_org_id and mem0_project_id:
self.memory = MemoryClient(
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
)
else:
self.memory = Memory() # Fallback to Memory if no Mem0 API key is provided
self.memory = MemoryClient(api_key=mem0_api_key)
def _sanitize_role(self, role: str) -> str:
"""

View File

@@ -281,16 +281,8 @@ class Telemetry:
return self._safe_telemetry_operation(operation)
def task_ended(self, span: Span, task: Task, crew: Crew):
"""Records the completion of a task execution in a crew.
"""Records task execution in a crew."""
Args:
span (Span): The OpenTelemetry span tracking the task execution
task (Task): The task that was completed
crew (Crew): The crew context in which the task was executed
Note:
If share_crew is enabled, this will also record the task output
"""
def operation():
if crew.share_crew:
self._add_attribute(
@@ -305,13 +297,8 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int):
"""Records when a tool is used repeatedly, which might indicate an issue.
"""Records the repeated usage 'error' of a tool by an agent."""
Args:
llm (Any): The language model being used
tool_name (str): Name of the tool being repeatedly used
attempts (int): Number of attempts made with this tool
"""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Tool Repeated Usage")
@@ -330,13 +317,8 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def tool_usage(self, llm: Any, tool_name: str, attempts: int):
"""Records the usage of a tool by an agent.
"""Records the usage of a tool by an agent."""
Args:
llm (Any): The language model being used
tool_name (str): Name of the tool being used
attempts (int): Number of attempts made with this tool
"""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Tool Usage")
@@ -355,11 +337,8 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def tool_usage_error(self, llm: Any):
"""Records when a tool usage results in an error.
"""Records the usage of a tool by an agent."""
Args:
llm (Any): The language model being used when the error occurred
"""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Tool Usage Error")
@@ -378,14 +357,6 @@ class Telemetry:
def individual_test_result_span(
self, crew: Crew, quality: float, exec_time: int, model_name: str
):
"""Records individual test results for a crew execution.
Args:
crew (Crew): The crew being tested
quality (float): Quality score of the execution
exec_time (int): Execution time in seconds
model_name (str): Name of the model used
"""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Crew Individual Test Result")
@@ -412,14 +383,6 @@ class Telemetry:
inputs: dict[str, Any] | None,
model_name: str,
):
"""Records the execution of a test suite for a crew.
Args:
crew (Crew): The crew being tested
iterations (int): Number of test iterations
inputs (dict[str, Any] | None): Input parameters for the test
model_name (str): Name of the model used in testing
"""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Crew Test Execution")
@@ -445,7 +408,6 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def deploy_signup_error_span(self):
"""Records when an error occurs during the deployment signup process."""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Deploy Signup Error")
@@ -455,11 +417,6 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def start_deployment_span(self, uuid: Optional[str] = None):
"""Records the start of a deployment process.
Args:
uuid (Optional[str]): Unique identifier for the deployment
"""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Start Deployment")
@@ -471,7 +428,6 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def create_crew_deployment_span(self):
"""Records the creation of a new crew deployment."""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Create Crew Deployment")
@@ -481,12 +437,6 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def get_crew_logs_span(self, uuid: Optional[str], log_type: str = "deployment"):
"""Records the retrieval of crew logs.
Args:
uuid (Optional[str]): Unique identifier for the crew
log_type (str, optional): Type of logs being retrieved. Defaults to "deployment".
"""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Get Crew Logs")
@@ -499,11 +449,6 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def remove_crew_span(self, uuid: Optional[str] = None):
"""Records the removal of a crew.
Args:
uuid (Optional[str]): Unique identifier for the crew being removed
"""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Remove Crew")
@@ -629,11 +574,6 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def flow_creation_span(self, flow_name: str):
"""Records the creation of a new flow.
Args:
flow_name (str): Name of the flow being created
"""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Flow Creation")
@@ -644,12 +584,6 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def flow_plotting_span(self, flow_name: str, node_names: list[str]):
"""Records flow visualization/plotting activity.
Args:
flow_name (str): Name of the flow being plotted
node_names (list[str]): List of node names in the flow
"""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Flow Plotting")
@@ -661,12 +595,6 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def flow_execution_span(self, flow_name: str, node_names: list[str]):
"""Records the execution of a flow.
Args:
flow_name (str): Name of the flow being executed
node_names (list[str]): List of nodes being executed in the flow
"""
def operation():
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Flow Execution")

View File

@@ -37,6 +37,8 @@ class BaseTool(BaseModel, ABC):
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
result_as_answer: bool = False
"""Flag to check if the tool should be the final agent answer."""
allow_repeated_usage: bool = False
"""Whether the tool permits repeated usage with same arguments."""
@validator("args_schema", always=True, pre=True)
def _default_args_schema(

View File

@@ -279,6 +279,10 @@ class ToolUsage:
if not self.tools_handler:
return False # type: ignore # No return value expected
if last_tool_usage := self.tools_handler.last_used_tool:
tool = self._select_tool(calling.tool_name)
# If the tool allows repeated usage, don't check arguments
if getattr(tool, "allow_repeated_usage", False):
return False # type: ignore # No return value expected
return (calling.tool_name == last_tool_usage.tool_name) and ( # type: ignore # No return value expected
calling.arguments == last_tool_usage.arguments
)
@@ -455,7 +459,7 @@ class ToolUsage:
# Attempt 4: Repair JSON
try:
repaired_input = repair_json(tool_input, skip_json_loads=True)
repaired_input = repair_json(tool_input)
self._printer.print(
content=f"Repaired JSON: {repaired_input}", color="blue"
)

View File

@@ -96,10 +96,6 @@ class CrewPlanner:
tasks_summary = []
for idx, task in enumerate(self.tasks):
knowledge_list = self._get_agent_knowledge(task)
agent_tools = (
f"[{', '.join(str(tool) for tool in task.agent.tools)}]" if task.agent and task.agent.tools else '"agent has no tools"',
f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"' if knowledge_list and str(knowledge_list) != "None" else ""
)
task_summary = f"""
Task Number {idx + 1} - {task.description}
"task_description": {task.description}
@@ -107,7 +103,10 @@ class CrewPlanner:
"agent": {task.agent.role if task.agent else "None"}
"agent_goal": {task.agent.goal if task.agent else "None"}
"task_tools": {task.tools}
"agent_tools": {"".join(agent_tools)}"""
"agent_tools": %s%s""" % (
f"[{', '.join(str(tool) for tool in task.agent.tools)}]" if task.agent and task.agent.tools else '"agent has no tools"',
f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"' if knowledge_list and str(knowledge_list) != "None" else ""
)
tasks_summary.append(task_summary)
return " ".join(tasks_summary)

View File

@@ -546,6 +546,47 @@ def test_agent_moved_on_after_max_iterations():
assert output == "42"
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_repeated_tool_usage_respects_allow_repeated_usage(capsys):
@tool
def repeatable_tool(anything: str) -> float:
"""A tool that allows being used repeatedly with the same input."""
return 42
# Patch the tool to set allow_repeated_usage to True
repeatable_tool.allow_repeated_usage = True
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
max_iter=4,
llm="gpt-4",
allow_delegation=False,
verbose=True,
)
task = Task(
description="Use the repeatable tool with the same input multiple times.",
expected_output="The result of using the repeatable tool",
)
# force cleaning cache
agent.tools_handler.cache = CacheHandler()
agent.execute_task(
task=task,
tools=[repeatable_tool],
)
captured = capsys.readouterr()
# Should NOT show the repeated usage error
assert (
"I tried reusing the same input, I must stop using this action input. I'll try something else instead."
not in captured.out
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_respect_the_max_rpm_set(capsys):
@tool

View File

@@ -0,0 +1,95 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are test role. test backstory\nYour
personal goal is: test goal\nYou ONLY have access to the following tools, and
should NEVER make up tools that are not listed here:\n\nTool Name: repeatable_tool\nTool
Arguments: {''anything'': {''description'': None, ''type'': ''str''}}\nTool
Description: A tool that allows being used repeatedly with the same input.\n\nIMPORTANT:
Use the following format in your response:\n\n```\nThought: you should always
think about what to do\nAction: the action to take, only one name of [repeatable_tool],
just the name, exactly as it''s written.\nAction Input: the input to the action,
just a simple JSON object, enclosed in curly braces, using \" to wrap keys and
values.\nObservation: the result of the action\n```\n\nOnce all necessary information
is gathered, return the following format:\n\n```\nThought: I now know the final
answer\nFinal Answer: the final answer to the original input question\n```"},
{"role": "user", "content": "\nCurrent Task: Use the repeatable tool with the
same input multiple times.\n\nThis is the expected criteria for your final answer:
The result of using the repeatable tool\nyou MUST return the actual complete
content as the final answer, not a summary.\n\nBegin! This is VERY important
to you, use the tools available and give your best Final Answer, your job depends
on it!\n\nThought:"}], "model": "gpt-4", "stop": ["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '1443'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.61.0
x-stainless-arch:
- x64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- Linux
x-stainless-package-version:
- 1.61.0
x-stainless-raw-response:
- 'true'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.7
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
content: "{\n \"error\": {\n \"message\": \"Incorrect API key provided:
sk-proj-********************************************************************************************************************************************************sLcA.
You can find your API key at https://platform.openai.com/account/api-keys.\",\n
\ \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\":
\"invalid_api_key\"\n }\n}\n"
headers:
CF-RAY:
- 923d7d097e94585c-SEA
Connection:
- keep-alive
Content-Length:
- '414'
Content-Type:
- application/json; charset=utf-8
Date:
- Fri, 21 Mar 2025 12:35:18 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=Q1SICHUtjtv5VIyjejhOK84VaDt9c0W.OVC6v_gypkQ-1742560518-1.0.1.1-qJY3vQUXlr.VsRsaGGOWSlwiIAp08q3Lt8WnlqIyScSZLPbR0lKV.af50DgwmKKgMtmbt3i27M3b_InDvj8zqEeADyauevHK67qwXvnimEo;
path=/; expires=Fri, 21-Mar-25 13:05:18 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=BrdkkJU.eT5POa3a8EbLDfLnncr8znx1G52s.PDhIOE-1742560518752-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
X-Content-Type-Options:
- nosniff
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
vary:
- Origin
x-request-id:
- req_88d97a89c4b41a4e306821c931177663
http_version: HTTP/1.1
status_code: 401
version: 1

View File

@@ -0,0 +1,49 @@
from unittest.mock import MagicMock
import pytest
from crewai.tools import BaseTool
from crewai.tools.tool_calling import ToolCalling
from crewai.tools.tool_usage import ToolUsage
def test_tool_repeated_usage_allowed():
"""Test that a tool with allow_repeated_usage=True can be used repeatedly with same args."""
class RepeatedUsageTool(BaseTool):
name: str = "Repeated Usage Tool"
description: str = "A tool that can be used repeatedly with the same arguments"
allow_repeated_usage: bool = True
def _run(self, test_arg: str) -> str:
return f"Used with arg: {test_arg}"
# Setup tool usage
tool = RepeatedUsageTool()
tools_handler = MagicMock()
tools_handler.last_used_tool = ToolCalling(
tool_name="Repeated Usage Tool",
arguments={"test_arg": "test"}
)
tool_usage = ToolUsage(
tools_handler=tools_handler,
tools=[tool],
original_tools=[tool],
tools_description="Test tools",
tools_names="Repeated Usage Tool",
agent=MagicMock(),
task=MagicMock(),
function_calling_llm=MagicMock(),
action=MagicMock(),
)
# Create a new tool calling with the same arguments
calling = ToolCalling(
tool_name="Repeated Usage Tool",
arguments={"test_arg": "test"}
)
# This should return False since the tool allows repeated usage
result = tool_usage._check_tool_repeated_usage(calling=calling)
assert result is False

View File

@@ -1,17 +1,35 @@
import json
import os
from datetime import date, datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union, cast
from typing import Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch
import pytest
from pydantic import BaseModel
from crewai.flow.state_utils import _to_serializable_key, to_serializable, to_string
from crewai.llm import LLM
from crewai.utilities.converter import (
Converter,
ConverterError,
convert_to_model,
convert_with_instructions,
create_converter,
generate_model_description,
get_conversion_instructions,
handle_partial_json,
validate_model,
)
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
# Sample Pydantic models for testing
class EmailResponse(BaseModel):
previous_message_content: str
class EmailResponses(BaseModel):
responses: list[EmailResponse]
class SimpleModel(BaseModel):
name: str
age: int
@@ -34,190 +52,560 @@ class Person(BaseModel):
address: Address
class Color(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"
class CustomConverter(Converter):
pass
class EnumModel(BaseModel):
name: str
color: Color
# Fixtures
@pytest.fixture
def mock_agent():
agent = Mock()
agent.function_calling_llm = None
agent.llm = Mock()
return agent
class OptionalModel(BaseModel):
name: str
age: Optional[int]
# Tests for convert_to_model
def test_convert_to_model_with_valid_json():
result = '{"name": "John", "age": 30}'
output = convert_to_model(result, SimpleModel, None, None)
assert isinstance(output, SimpleModel)
assert output.name == "John"
assert output.age == 30
class ListModel(BaseModel):
items: List[int]
def test_convert_to_model_with_invalid_json():
result = '{"name": "John", "age": "thirty"}'
with patch("crewai.utilities.converter.handle_partial_json") as mock_handle:
mock_handle.return_value = "Fallback result"
output = convert_to_model(result, SimpleModel, None, None)
assert output == "Fallback result"
class UnionModel(BaseModel):
field: Union[int, str, None]
def test_convert_to_model_with_no_model():
result = "Plain text"
output = convert_to_model(result, None, None, None)
assert output == "Plain text"
# Tests for to_serializable function
def test_to_serializable_primitives():
"""Test serialization of primitive types."""
assert to_serializable("test string") == "test string"
assert to_serializable(42) == 42
assert to_serializable(3.14) == 3.14
assert to_serializable(True) == True
assert to_serializable(None) is None
def test_to_serializable_dates():
"""Test serialization of date and datetime objects."""
test_date = date(2023, 1, 15)
test_datetime = datetime(2023, 1, 15, 10, 30, 45)
assert to_serializable(test_date) == "2023-01-15"
assert to_serializable(test_datetime) == "2023-01-15T10:30:45"
def test_to_serializable_collections():
"""Test serialization of lists, tuples, and sets."""
test_list = [1, "two", 3.0]
test_tuple = (4, "five", 6.0)
test_set = {7, "eight", 9.0}
assert to_serializable(test_list) == [1, "two", 3.0]
assert to_serializable(test_tuple) == [4, "five", 6.0]
# For sets, we can't rely on order, so we'll verify differently
serialized_set = to_serializable(test_set)
assert isinstance(serialized_set, list)
assert len(serialized_set) == 3
assert 7 in serialized_set
assert "eight" in serialized_set
assert 9.0 in serialized_set
def test_to_serializable_dict():
"""Test serialization of dictionaries."""
test_dict = {"a": 1, "b": "two", "c": [3, 4, 5]}
assert to_serializable(test_dict) == {"a": 1, "b": "two", "c": [3, 4, 5]}
def test_to_serializable_pydantic_models():
"""Test serialization of Pydantic models."""
simple = SimpleModel(name="John", age=30)
assert to_serializable(simple) == {"name": "John", "age": 30}
def test_to_serializable_nested_models():
"""Test serialization of nested Pydantic models."""
simple = SimpleModel(name="John", age=30)
nested = NestedModel(id=1, data=simple)
assert to_serializable(nested) == {"id": 1, "data": {"name": "John", "age": 30}}
def test_to_serializable_complex_model():
"""Test serialization of a complex model with nested structures."""
person = Person(
name="Jane",
age=28,
address=Address(street="123 Main St", city="Anytown", zip_code="12345"),
def test_convert_to_model_with_special_characters():
json_string_test = """
{
"responses": [
{
"previous_message_content": "Hi Tom,\r\n\r\nNiamh has chosen the Mika phonics on"
}
]
}
"""
output = convert_to_model(json_string_test, EmailResponses, None, None)
assert isinstance(output, EmailResponses)
assert len(output.responses) == 1
assert (
output.responses[0].previous_message_content
== "Hi Tom,\r\n\r\nNiamh has chosen the Mika phonics on"
)
assert to_serializable(person) == {
"name": "Jane",
"age": 28,
"address": {"street": "123 Main St", "city": "Anytown", "zip_code": "12345"},
def test_convert_to_model_with_escaped_special_characters():
json_string_test = json.dumps(
{
"responses": [
{
"previous_message_content": "Hi Tom,\r\n\r\nNiamh has chosen the Mika phonics on"
}
]
}
)
output = convert_to_model(json_string_test, EmailResponses, None, None)
assert isinstance(output, EmailResponses)
assert len(output.responses) == 1
assert (
output.responses[0].previous_message_content
== "Hi Tom,\r\n\r\nNiamh has chosen the Mika phonics on"
)
def test_convert_to_model_with_multiple_special_characters():
json_string_test = """
{
"responses": [
{
"previous_message_content": "Line 1\r\nLine 2\tTabbed\nLine 3\r\n\rEscaped newline"
}
]
}
"""
output = convert_to_model(json_string_test, EmailResponses, None, None)
assert isinstance(output, EmailResponses)
assert len(output.responses) == 1
assert (
output.responses[0].previous_message_content
== "Line 1\r\nLine 2\tTabbed\nLine 3\r\n\rEscaped newline"
)
def test_to_serializable_enum():
"""Test serialization of Enum values."""
model = EnumModel(name="ColorTest", color=Color.RED)
assert to_serializable(model) == {"name": "ColorTest", "color": "red"}
# Tests for validate_model
def test_validate_model_pydantic_output():
result = '{"name": "Alice", "age": 25}'
output = validate_model(result, SimpleModel, False)
assert isinstance(output, SimpleModel)
assert output.name == "Alice"
assert output.age == 25
def test_to_serializable_optional_fields():
"""Test serialization of models with optional fields."""
model_with_age = OptionalModel(name="WithAge", age=25)
model_without_age = OptionalModel(name="WithoutAge", age=None)
assert to_serializable(model_with_age) == {"name": "WithAge", "age": 25}
assert to_serializable(model_without_age) == {"name": "WithoutAge", "age": None}
def test_validate_model_json_output():
result = '{"name": "Bob", "age": 40}'
output = validate_model(result, SimpleModel, True)
assert isinstance(output, dict)
assert output == {"name": "Bob", "age": 40}
def test_to_serializable_list_field():
"""Test serialization of models with list fields."""
model = ListModel(items=[1, 2, 3, 4, 5])
assert to_serializable(model) == {"items": [1, 2, 3, 4, 5]}
# Tests for handle_partial_json
def test_handle_partial_json_with_valid_partial():
result = 'Some text {"name": "Charlie", "age": 35} more text'
output = handle_partial_json(result, SimpleModel, False, None)
assert isinstance(output, SimpleModel)
assert output.name == "Charlie"
assert output.age == 35
def test_to_serializable_union_field():
"""Test serialization of models with union fields."""
model_int = UnionModel(field=42)
model_str = UnionModel(field="test")
model_none = UnionModel(field=None)
assert to_serializable(model_int) == {"field": 42}
assert to_serializable(model_str) == {"field": "test"}
assert to_serializable(model_none) == {"field": None}
def test_handle_partial_json_with_invalid_partial(mock_agent):
result = "No valid JSON here"
with patch("crewai.utilities.converter.convert_with_instructions") as mock_convert:
mock_convert.return_value = "Converted result"
output = handle_partial_json(result, SimpleModel, False, mock_agent)
assert output == "Converted result"
def test_to_serializable_max_depth():
"""Test max depth parameter to prevent infinite recursion."""
# Create recursive structure
a: Dict[str, Any] = {"name": "a"}
b: Dict[str, Any] = {"name": "b", "ref": a}
a["ref"] = b # Create circular reference
# Tests for convert_with_instructions
@patch("crewai.utilities.converter.create_converter")
@patch("crewai.utilities.converter.get_conversion_instructions")
def test_convert_with_instructions_success(
mock_get_instructions, mock_create_converter, mock_agent
):
mock_get_instructions.return_value = "Instructions"
mock_converter = Mock()
mock_converter.to_pydantic.return_value = SimpleModel(name="David", age=50)
mock_create_converter.return_value = mock_converter
result = to_serializable(a, max_depth=3)
result = "Some text to convert"
output = convert_with_instructions(result, SimpleModel, False, mock_agent)
assert isinstance(result, dict)
assert "name" in result
assert "ref" in result
assert isinstance(result["ref"], dict)
assert "ref" in result["ref"]
assert isinstance(result["ref"]["ref"], dict)
# At depth 3, it should convert to string
assert isinstance(result["ref"]["ref"]["ref"], str)
assert isinstance(output, SimpleModel)
assert output.name == "David"
assert output.age == 50
def test_to_serializable_non_serializable():
"""Test serialization of objects that aren't directly JSON serializable."""
@patch("crewai.utilities.converter.create_converter")
@patch("crewai.utilities.converter.get_conversion_instructions")
def test_convert_with_instructions_failure(
mock_get_instructions, mock_create_converter, mock_agent
):
mock_get_instructions.return_value = "Instructions"
mock_converter = Mock()
mock_converter.to_pydantic.return_value = ConverterError("Conversion failed")
mock_create_converter.return_value = mock_converter
class CustomObject:
def __repr__(self):
return "CustomObject()"
obj = CustomObject()
# Should convert to string representation
assert to_serializable(obj) == "CustomObject()"
result = "Some text to convert"
with patch("crewai.utilities.converter.Printer") as mock_printer:
output = convert_with_instructions(result, SimpleModel, False, mock_agent)
assert output == result
mock_printer.return_value.print.assert_called_once()
def test_to_string_conversion():
"""Test the to_string function."""
test_dict = {"name": "Test", "values": [1, 2, 3]}
# Should convert to a JSON string
assert to_string(test_dict) == '{"name": "Test", "values": [1, 2, 3]}'
# None should return None
assert to_string(None) is None
# Tests for get_conversion_instructions
def test_get_conversion_instructions_gpt():
llm = LLM(model="gpt-4o-mini")
with patch.object(LLM, "supports_function_calling") as supports_function_calling:
supports_function_calling.return_value = True
instructions = get_conversion_instructions(SimpleModel, llm)
model_schema = PydanticSchemaParser(model=SimpleModel).get_schema()
expected_instructions = (
"Please convert the following text into valid JSON.\n\n"
"Output ONLY the valid JSON and nothing else.\n\n"
"The JSON must follow this schema exactly:\n```json\n"
f"{model_schema}\n```"
)
assert instructions == expected_instructions
def test_to_serializable_key():
"""Test serialization of dictionary keys."""
# String and int keys are converted to strings
assert _to_serializable_key("test") == "test"
assert _to_serializable_key(42) == "42"
def test_get_conversion_instructions_non_gpt():
llm = LLM(model="ollama/llama3.1", base_url="http://localhost:11434")
with patch.object(LLM, "supports_function_calling", return_value=False):
instructions = get_conversion_instructions(SimpleModel, llm)
assert '"name": str' in instructions
assert '"age": int' in instructions
# Complex objects are converted to a unique string
obj = object()
key_str = _to_serializable_key(obj)
assert isinstance(key_str, str)
assert "key_" in key_str
assert "object" in key_str
# Tests for is_gpt
def test_supports_function_calling_true():
llm = LLM(model="gpt-4o")
assert llm.supports_function_calling() is True
def test_supports_function_calling_false():
llm = LLM(model="non-existent-model")
assert llm.supports_function_calling() is False
def test_create_converter_with_mock_agent():
mock_agent = MagicMock()
mock_agent.get_output_converter.return_value = MagicMock(spec=Converter)
converter = create_converter(
agent=mock_agent,
llm=Mock(),
text="Sample",
model=SimpleModel,
instructions="Convert",
)
assert isinstance(converter, Converter)
mock_agent.get_output_converter.assert_called_once()
def test_create_converter_with_custom_converter():
converter = create_converter(
converter_cls=CustomConverter,
llm=LLM(model="gpt-4o-mini"),
text="Sample",
model=SimpleModel,
instructions="Convert",
)
assert isinstance(converter, CustomConverter)
def test_create_converter_fails_without_agent_or_converter_cls():
with pytest.raises(
ValueError, match="Either agent or converter_cls must be provided"
):
create_converter(
llm=Mock(), text="Sample", model=SimpleModel, instructions="Convert"
)
def test_generate_model_description_simple_model():
description = generate_model_description(SimpleModel)
expected_description = '{\n "name": str,\n "age": int\n}'
assert description == expected_description
def test_generate_model_description_nested_model():
description = generate_model_description(NestedModel)
expected_description = (
'{\n "id": int,\n "data": {\n "name": str,\n "age": int\n}\n}'
)
assert description == expected_description
def test_generate_model_description_optional_field():
class ModelWithOptionalField(BaseModel):
name: Optional[str]
age: int
description = generate_model_description(ModelWithOptionalField)
expected_description = '{\n "name": Optional[str],\n "age": int\n}'
assert description == expected_description
def test_generate_model_description_list_field():
class ModelWithListField(BaseModel):
items: List[int]
description = generate_model_description(ModelWithListField)
expected_description = '{\n "items": List[int]\n}'
assert description == expected_description
def test_generate_model_description_dict_field():
class ModelWithDictField(BaseModel):
attributes: Dict[str, int]
description = generate_model_description(ModelWithDictField)
expected_description = '{\n "attributes": Dict[str, int]\n}'
assert description == expected_description
@pytest.mark.vcr(filter_headers=["authorization"])
def test_convert_with_instructions():
llm = LLM(model="gpt-4o-mini")
sample_text = "Name: Alice, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
)
# Act
output = converter.to_pydantic()
# Assert
assert isinstance(output, SimpleModel)
assert output.name == "Alice"
assert output.age == 30
# Skip tests that call external APIs when running in CI/CD
skip_external_api = pytest.mark.skipif(
os.getenv("CI") is not None, reason="Skipping tests that call external API in CI/CD"
)
@skip_external_api
@pytest.mark.vcr(filter_headers=["authorization"], record_mode="once")
def test_converter_with_llama3_2_model():
llm = LLM(model="ollama/llama3.2:3b", base_url="http://localhost:11434")
sample_text = "Name: Alice Llama, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, SimpleModel)
assert output.name == "Alice Llama"
assert output.age == 30
@skip_external_api
@pytest.mark.vcr(filter_headers=["authorization"], record_mode="once")
def test_converter_with_llama3_1_model():
llm = LLM(model="ollama/llama3.1", base_url="http://localhost:11434")
sample_text = "Name: Alice Llama, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, SimpleModel)
assert output.name == "Alice Llama"
assert output.age == 30
# Skip tests that call external APIs when running in CI/CD
skip_external_api = pytest.mark.skipif(
os.getenv("CI") is not None, reason="Skipping tests that call external API in CI/CD"
)
@skip_external_api
@pytest.mark.vcr(filter_headers=["authorization"])
def test_converter_with_nested_model():
llm = LLM(model="gpt-4o-mini")
sample_text = "Name: John Doe\nAge: 30\nAddress: 123 Main St, Anytown, 12345"
instructions = get_conversion_instructions(Person, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=Person,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, Person)
assert output.name == "John Doe"
assert output.age == 30
assert isinstance(output.address, Address)
assert output.address.street == "123 Main St"
assert output.address.city == "Anytown"
assert output.address.zip_code == "12345"
# Tests for error handling
def test_converter_error_handling():
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.return_value = "Invalid JSON"
sample_text = "Name: Alice, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
)
with pytest.raises(ConverterError) as exc_info:
output = converter.to_pydantic()
assert "Failed to convert text into a Pydantic model" in str(exc_info.value)
# Tests for retry logic
def test_converter_retry_logic():
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.side_effect = [
"Invalid JSON",
"Still invalid",
'{"name": "Retry Alice", "age": 30}',
]
sample_text = "Name: Retry Alice, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
max_attempts=3,
)
output = converter.to_pydantic()
assert isinstance(output, SimpleModel)
assert output.name == "Retry Alice"
assert output.age == 30
assert llm.call.call_count == 3
# Tests for optional fields
def test_converter_with_optional_fields():
class OptionalModel(BaseModel):
name: str
age: Optional[int]
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
# Simulate the LLM's response with 'age' explicitly set to null
llm.call.return_value = '{"name": "Bob", "age": null}'
sample_text = "Name: Bob, age: None"
instructions = get_conversion_instructions(OptionalModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=OptionalModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, OptionalModel)
assert output.name == "Bob"
assert output.age is None
# Tests for list fields
def test_converter_with_list_field():
class ListModel(BaseModel):
items: List[int]
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.return_value = '{"items": [1, 2, 3]}'
sample_text = "Items: 1, 2, 3"
instructions = get_conversion_instructions(ListModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=ListModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, ListModel)
assert output.items == [1, 2, 3]
# Tests for enums
from enum import Enum
def test_converter_with_enum():
class Color(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"
class EnumModel(BaseModel):
name: str
color: Color
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.return_value = '{"name": "Alice", "color": "red"}'
sample_text = "Name: Alice, Color: Red"
instructions = get_conversion_instructions(EnumModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=EnumModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, EnumModel)
assert output.name == "Alice"
assert output.color == Color.RED
# Tests for ambiguous input
def test_converter_with_ambiguous_input():
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.return_value = '{"name": "Charlie", "age": "Not an age"}'
sample_text = "Charlie is thirty years old"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
)
with pytest.raises(ConverterError) as exc_info:
output = converter.to_pydantic()
assert "failed to convert text into a pydantic model" in str(exc_info.value).lower()
# Tests for function calling support
def test_converter_with_function_calling():
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = True
instructor = Mock()
instructor.to_pydantic.return_value = SimpleModel(name="Eve", age=35)
converter = Converter(
llm=llm,
text="Name: Eve, Age: 35",
model=SimpleModel,
instructions="Convert this text.",
)
converter._create_instructor = Mock(return_value=instructor)
output = converter.to_pydantic()
assert isinstance(output, SimpleModel)
assert output.name == "Eve"
assert output.age == 35
instructor.to_pydantic.assert_called_once()
def test_generate_model_description_union_field():
class UnionModel(BaseModel):
field: int | str | None
description = generate_model_description(UnionModel)
expected_description = '{\n "field": int | str | None\n}'
assert description == expected_description