mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-24 07:38:14 +00:00
Compare commits
2 Commits
lorenze/fi
...
devin/1756
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75b7c579f6 | ||
|
|
89fcd2a5b4 |
@@ -1,17 +1,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import Any, Callable, Optional, Union, get_type_hints
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional, Union, get_type_hints
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, create_model
|
from pydantic import BaseModel, Field, create_model
|
||||||
|
|
||||||
from crewai.utilities.logger import Logger
|
from crewai.utilities.logger import Logger
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
|
|
||||||
@@ -68,7 +66,7 @@ class CrewStructuredTool:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_function(
|
def from_function(
|
||||||
cls,
|
cls,
|
||||||
func: Callable,
|
func: Callable[..., Any],
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
description: Optional[str] = None,
|
description: Optional[str] = None,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
@@ -129,7 +127,7 @@ class CrewStructuredTool:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_schema_from_function(
|
def _create_schema_from_function(
|
||||||
name: str,
|
name: str,
|
||||||
func: Callable,
|
func: Callable[..., Any],
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create a Pydantic schema from a function's signature.
|
"""Create a Pydantic schema from a function's signature.
|
||||||
|
|
||||||
@@ -164,7 +162,7 @@ class CrewStructuredTool:
|
|||||||
|
|
||||||
# Create model
|
# Create model
|
||||||
schema_name = f"{name.title()}Schema"
|
schema_name = f"{name.title()}Schema"
|
||||||
return create_model(schema_name, **fields)
|
return create_model(schema_name, **fields) # type: ignore[call-overload,no-any-return]
|
||||||
|
|
||||||
def _validate_function_signature(self) -> None:
|
def _validate_function_signature(self) -> None:
|
||||||
"""Validate that the function signature matches the args schema."""
|
"""Validate that the function signature matches the args schema."""
|
||||||
@@ -192,7 +190,7 @@ class CrewStructuredTool:
|
|||||||
f"not found in args_schema"
|
f"not found in args_schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
|
def _parse_args(self, raw_args: str | dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Parse and validate the input arguments against the schema.
|
"""Parse and validate the input arguments against the schema.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -217,8 +215,8 @@ class CrewStructuredTool:
|
|||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
input: Union[str, dict],
|
input: str | dict[str, Any],
|
||||||
config: Optional[dict] = None,
|
config: Optional[dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Asynchronously invoke the tool.
|
"""Asynchronously invoke the tool.
|
||||||
@@ -253,7 +251,7 @@ class CrewStructuredTool:
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _run(self, *args, **kwargs) -> Any:
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
"""Legacy method for compatibility."""
|
"""Legacy method for compatibility."""
|
||||||
# Convert args/kwargs to our expected format
|
# Convert args/kwargs to our expected format
|
||||||
input_dict = dict(zip(self.args_schema.model_fields.keys(), args))
|
input_dict = dict(zip(self.args_schema.model_fields.keys(), args))
|
||||||
@@ -261,7 +259,10 @@ class CrewStructuredTool:
|
|||||||
return self.invoke(input_dict)
|
return self.invoke(input_dict)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
|
self,
|
||||||
|
input: str | dict[str, Any],
|
||||||
|
config: Optional[dict[str, Any]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Main method for tool execution."""
|
"""Main method for tool execution."""
|
||||||
parsed_args = self._parse_args(input)
|
parsed_args = self._parse_args(input)
|
||||||
@@ -273,22 +274,40 @@ class CrewStructuredTool:
|
|||||||
|
|
||||||
self._increment_usage_count()
|
self._increment_usage_count()
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(self.func):
|
|
||||||
result = asyncio.run(self.func(**parsed_args, **kwargs))
|
|
||||||
return result
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.func(**parsed_args, **kwargs)
|
if inspect.iscoroutinefunction(self.func):
|
||||||
|
coro = self.func(**parsed_args, **kwargs)
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot call async tool '{self.name}' from synchronous context within an event loop. "
|
||||||
|
f"Use ainvoke() instead or call from outside the event loop."
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "Cannot call async tool" in str(e):
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
return asyncio.run(coro)
|
||||||
|
else:
|
||||||
|
result = self.func(**parsed_args, **kwargs)
|
||||||
|
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Sync function '{self.name}' returned a coroutine but we're in an event loop. "
|
||||||
|
f"Use ainvoke() instead or call from outside the event loop."
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "returned a coroutine but we're in an event loop" in str(e):
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
return asyncio.run(result)
|
||||||
|
|
||||||
|
return result
|
||||||
except Exception:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
result = self.func(**parsed_args, **kwargs)
|
|
||||||
|
|
||||||
if asyncio.iscoroutine(result):
|
|
||||||
return asyncio.run(result)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def has_reached_max_usage_count(self) -> bool:
|
def has_reached_max_usage_count(self) -> bool:
|
||||||
"""Check if the tool has reached its maximum usage count."""
|
"""Check if the tool has reached its maximum usage count."""
|
||||||
return (
|
return (
|
||||||
@@ -303,9 +322,9 @@ class CrewStructuredTool:
|
|||||||
self._original_tool.current_usage_count = self.current_usage_count
|
self._original_tool.current_usage_count = self.current_usage_count
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> dict:
|
def args(self) -> dict[str, Any]:
|
||||||
"""Get the tool's input arguments schema."""
|
"""Get the tool's input arguments schema."""
|
||||||
return self.args_schema.model_json_schema()["properties"]
|
return self.args_schema.model_json_schema()["properties"] # type: ignore[no-any-return]
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
from typing import Optional
|
from collections.abc import Callable
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
from crewai.tools.structured_tool import CrewStructuredTool
|
from crewai.tools.structured_tool import CrewStructuredTool
|
||||||
|
|
||||||
|
|
||||||
# Test fixtures
|
# Test fixtures
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def basic_function():
|
def basic_function() -> Callable[[str, int], str]:
|
||||||
def test_func(param1: str, param2: int = 0) -> str:
|
def test_func(param1: str, param2: int = 0) -> str:
|
||||||
"""Test function with basic params."""
|
"""Test function with basic params."""
|
||||||
return f"{param1} {param2}"
|
return f"{param1} {param2}"
|
||||||
@@ -17,7 +19,7 @@ def basic_function():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def schema_class():
|
def schema_class() -> type[BaseModel]:
|
||||||
class TestSchema(BaseModel):
|
class TestSchema(BaseModel):
|
||||||
param1: str
|
param1: str
|
||||||
param2: int = Field(default=0)
|
param2: int = Field(default=0)
|
||||||
@@ -25,7 +27,9 @@ def schema_class():
|
|||||||
return TestSchema
|
return TestSchema
|
||||||
|
|
||||||
|
|
||||||
def test_initialization(basic_function, schema_class):
|
def test_initialization(
|
||||||
|
basic_function: Callable[[str], str], schema_class: type[BaseModel]
|
||||||
|
) -> None:
|
||||||
"""Test basic initialization of CrewStructuredTool"""
|
"""Test basic initialization of CrewStructuredTool"""
|
||||||
tool = CrewStructuredTool(
|
tool = CrewStructuredTool(
|
||||||
name="test_tool",
|
name="test_tool",
|
||||||
@@ -39,7 +43,8 @@ def test_initialization(basic_function, schema_class):
|
|||||||
assert tool.func == basic_function
|
assert tool.func == basic_function
|
||||||
assert tool.args_schema == schema_class
|
assert tool.args_schema == schema_class
|
||||||
|
|
||||||
def test_from_function(basic_function):
|
|
||||||
|
def test_from_function(basic_function: Callable[[str], str]) -> None:
|
||||||
"""Test creating tool from function"""
|
"""Test creating tool from function"""
|
||||||
tool = CrewStructuredTool.from_function(
|
tool = CrewStructuredTool.from_function(
|
||||||
func=basic_function, name="test_tool", description="Test description"
|
func=basic_function, name="test_tool", description="Test description"
|
||||||
@@ -50,7 +55,10 @@ def test_from_function(basic_function):
|
|||||||
assert tool.func == basic_function
|
assert tool.func == basic_function
|
||||||
assert isinstance(tool.args_schema, type(BaseModel))
|
assert isinstance(tool.args_schema, type(BaseModel))
|
||||||
|
|
||||||
def test_validate_function_signature(basic_function, schema_class):
|
|
||||||
|
def test_validate_function_signature(
|
||||||
|
basic_function: Callable[[str, int], str], schema_class: type[BaseModel]
|
||||||
|
) -> None:
|
||||||
"""Test function signature validation"""
|
"""Test function signature validation"""
|
||||||
tool = CrewStructuredTool(
|
tool = CrewStructuredTool(
|
||||||
name="test_tool",
|
name="test_tool",
|
||||||
@@ -62,15 +70,17 @@ def test_validate_function_signature(basic_function, schema_class):
|
|||||||
# Should not raise any exceptions
|
# Should not raise any exceptions
|
||||||
tool._validate_function_signature()
|
tool._validate_function_signature()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ainvoke(basic_function):
|
async def test_ainvoke(basic_function: Callable[[str, int], str]) -> None:
|
||||||
"""Test asynchronous invocation"""
|
"""Test asynchronous invocation"""
|
||||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||||
|
|
||||||
result = await tool.ainvoke(input={"param1": "test"})
|
result = await tool.ainvoke(input={"param1": "test"})
|
||||||
assert result == "test 0"
|
assert result == "test 0"
|
||||||
|
|
||||||
def test_parse_args_dict(basic_function):
|
|
||||||
|
def test_parse_args_dict(basic_function: Callable[[str, int], str]) -> None:
|
||||||
"""Test parsing dictionary arguments"""
|
"""Test parsing dictionary arguments"""
|
||||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||||
|
|
||||||
@@ -78,7 +88,8 @@ def test_parse_args_dict(basic_function):
|
|||||||
assert parsed["param1"] == "test"
|
assert parsed["param1"] == "test"
|
||||||
assert parsed["param2"] == 42
|
assert parsed["param2"] == 42
|
||||||
|
|
||||||
def test_parse_args_string(basic_function):
|
|
||||||
|
def test_parse_args_string(basic_function: Callable[[str, int], str]) -> None:
|
||||||
"""Test parsing string arguments"""
|
"""Test parsing string arguments"""
|
||||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||||
|
|
||||||
@@ -86,10 +97,11 @@ def test_parse_args_string(basic_function):
|
|||||||
assert parsed["param1"] == "test"
|
assert parsed["param1"] == "test"
|
||||||
assert parsed["param2"] == 42
|
assert parsed["param2"] == 42
|
||||||
|
|
||||||
def test_complex_types():
|
|
||||||
|
def test_complex_types() -> None:
|
||||||
"""Test handling of complex parameter types"""
|
"""Test handling of complex parameter types"""
|
||||||
|
|
||||||
def complex_func(nested: dict, items: list) -> str:
|
def complex_func(nested: dict[str, Any], items: list[Any]) -> str:
|
||||||
"""Process complex types."""
|
"""Process complex types."""
|
||||||
return f"Processed {len(items)} items with {len(nested)} nested keys"
|
return f"Processed {len(items)} items with {len(nested)} nested keys"
|
||||||
|
|
||||||
@@ -99,7 +111,8 @@ def test_complex_types():
|
|||||||
result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]})
|
result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]})
|
||||||
assert result == "Processed 3 items with 1 nested keys"
|
assert result == "Processed 3 items with 1 nested keys"
|
||||||
|
|
||||||
def test_schema_inheritance():
|
|
||||||
|
def test_schema_inheritance() -> None:
|
||||||
"""Test tool creation with inherited schema"""
|
"""Test tool creation with inherited schema"""
|
||||||
|
|
||||||
def extended_func(base_param: str, extra_param: int) -> str:
|
def extended_func(base_param: str, extra_param: int) -> str:
|
||||||
@@ -119,7 +132,8 @@ def test_schema_inheritance():
|
|||||||
result = tool.invoke({"base_param": "test", "extra_param": 42})
|
result = tool.invoke({"base_param": "test", "extra_param": 42})
|
||||||
assert result == "test 42"
|
assert result == "test 42"
|
||||||
|
|
||||||
def test_default_values_in_schema():
|
|
||||||
|
def test_default_values_in_schema() -> None:
|
||||||
"""Test handling of default values in schema"""
|
"""Test handling of default values in schema"""
|
||||||
|
|
||||||
def default_func(
|
def default_func(
|
||||||
@@ -144,19 +158,21 @@ def test_default_values_in_schema():
|
|||||||
)
|
)
|
||||||
assert result == "test custom 42"
|
assert result == "test custom 42"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def custom_tool_decorator():
|
def custom_tool_decorator() -> Any:
|
||||||
from crewai.tools import tool
|
from crewai.tools import tool
|
||||||
|
|
||||||
@tool("custom_tool", result_as_answer=True)
|
@tool("custom_tool", result_as_answer=True)
|
||||||
async def custom_tool():
|
async def custom_tool() -> str:
|
||||||
"""This is a tool that does something"""
|
"""This is a tool that does something"""
|
||||||
return "Hello World from Custom Tool"
|
return "Hello World from Custom Tool"
|
||||||
|
|
||||||
return custom_tool
|
return custom_tool
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def custom_tool():
|
def custom_tool() -> BaseTool:
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
class CustomTool(BaseTool):
|
class CustomTool(BaseTool):
|
||||||
@@ -164,46 +180,57 @@ def custom_tool():
|
|||||||
description: str = "This is a tool that does something"
|
description: str = "This is a tool that does something"
|
||||||
result_as_answer: bool = True
|
result_as_answer: bool = True
|
||||||
|
|
||||||
async def _run(self):
|
async def _run(self) -> str:
|
||||||
return "Hello World from Custom Tool"
|
return "Hello World from Custom Tool"
|
||||||
|
|
||||||
return CustomTool()
|
return CustomTool()
|
||||||
|
|
||||||
def build_simple_crew(tool):
|
|
||||||
from crewai import Agent, Task, Crew
|
|
||||||
|
|
||||||
agent1 = Agent(role="Simple role", goal="Simple goal", backstory="Simple backstory", tools=[tool])
|
def build_simple_crew(tool: Any) -> Any:
|
||||||
|
from crewai import Agent, Crew, Task
|
||||||
|
|
||||||
|
agent1 = Agent(
|
||||||
|
role="Simple role",
|
||||||
|
goal="Simple goal",
|
||||||
|
backstory="Simple backstory",
|
||||||
|
tools=[tool],
|
||||||
|
)
|
||||||
|
|
||||||
say_hi_task = Task(
|
say_hi_task = Task(
|
||||||
description="Use the custom tool result as answer.", agent=agent1, expected_output="Use the tool result"
|
description="Use the custom tool result as answer.",
|
||||||
|
agent=agent1,
|
||||||
|
expected_output="Use the tool result",
|
||||||
)
|
)
|
||||||
|
|
||||||
crew = Crew(agents=[agent1], tasks=[say_hi_task])
|
crew = Crew(agents=[agent1], tasks=[say_hi_task])
|
||||||
return crew
|
return crew
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_async_tool_using_within_isolated_crew(custom_tool):
|
def test_async_tool_using_within_isolated_crew(custom_tool: BaseTool) -> None:
|
||||||
crew = build_simple_crew(custom_tool)
|
crew = build_simple_crew(custom_tool)
|
||||||
result = crew.kickoff()
|
result = crew.kickoff()
|
||||||
|
|
||||||
assert result.raw == "Hello World from Custom Tool"
|
assert result.raw == "Hello World from Custom Tool"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
|
def test_async_tool_using_decorator_within_isolated_crew(
|
||||||
|
custom_tool_decorator: Any,
|
||||||
|
) -> None:
|
||||||
crew = build_simple_crew(custom_tool_decorator)
|
crew = build_simple_crew(custom_tool_decorator)
|
||||||
result = crew.kickoff()
|
result = crew.kickoff()
|
||||||
|
|
||||||
assert result.raw == "Hello World from Custom Tool"
|
assert result.raw == "Hello World from Custom Tool"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_async_tool_within_flow(custom_tool):
|
def test_async_tool_within_flow(custom_tool: BaseTool) -> None:
|
||||||
from crewai.flow.flow import Flow
|
from crewai.flow.flow import Flow, start
|
||||||
|
|
||||||
class StructuredExampleFlow(Flow):
|
|
||||||
from crewai.flow.flow import start
|
|
||||||
|
|
||||||
|
class StructuredExampleFlow(Flow): # type: ignore[type-arg]
|
||||||
@start()
|
@start()
|
||||||
async def start(self):
|
async def start(self) -> Any:
|
||||||
crew = build_simple_crew(custom_tool)
|
crew = build_simple_crew(custom_tool)
|
||||||
result = await crew.kickoff_async()
|
result = await crew.kickoff_async()
|
||||||
return result
|
return result
|
||||||
@@ -214,17 +241,110 @@ def test_async_tool_within_flow(custom_tool):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
|
def test_async_tool_using_decorator_within_flow(custom_tool_decorator: Any) -> None:
|
||||||
from crewai.flow.flow import Flow
|
from crewai.flow.flow import Flow, start
|
||||||
|
|
||||||
class StructuredExampleFlow(Flow):
|
class StructuredExampleFlow(Flow): # type: ignore[type-arg]
|
||||||
from crewai.flow.flow import start
|
|
||||||
@start()
|
@start()
|
||||||
async def start(self):
|
async def start(self) -> Any:
|
||||||
crew = build_simple_crew(custom_tool_decorator)
|
crew = build_simple_crew(custom_tool_decorator)
|
||||||
result = await crew.kickoff_async()
|
result = await crew.kickoff_async()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
flow = StructuredExampleFlow()
|
flow = StructuredExampleFlow()
|
||||||
result = flow.kickoff()
|
result = flow.kickoff()
|
||||||
assert result.raw == "Hello World from Custom Tool"
|
assert result.raw == "Hello World from Custom Tool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_sync_function_single_execution() -> None:
|
||||||
|
"""Test that sync functions are called only once, not twice."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def counting_func(message: str) -> str:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return f"Called {call_count} times with: {message}"
|
||||||
|
|
||||||
|
tool = CrewStructuredTool.from_function(
|
||||||
|
func=counting_func, name="counting_tool", description="A tool that counts calls"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = tool.invoke({"message": "test"})
|
||||||
|
assert call_count == 1, f"Function was called {call_count} times, expected 1"
|
||||||
|
assert result == "Called 1 times with: test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_async_function_outside_event_loop() -> None:
|
||||||
|
"""Test that async functions work correctly when called outside event loop."""
|
||||||
|
|
||||||
|
async def async_func(message: str) -> str:
|
||||||
|
return f"Async result: {message}"
|
||||||
|
|
||||||
|
tool = CrewStructuredTool.from_function(
|
||||||
|
func=async_func, name="async_tool", description="An async tool"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = tool.invoke({"message": "test"})
|
||||||
|
assert result == "Async result: test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_async_function_in_event_loop_raises_error() -> None:
|
||||||
|
"""Test that async functions raise RuntimeError when called from within event loop."""
|
||||||
|
|
||||||
|
async def async_func(message: str) -> str:
|
||||||
|
return f"Async result: {message}"
|
||||||
|
|
||||||
|
tool = CrewStructuredTool.from_function(
|
||||||
|
func=async_func, name="async_tool", description="An async tool"
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
RuntimeError,
|
||||||
|
match="Cannot call async tool.*from synchronous context within an event loop",
|
||||||
|
):
|
||||||
|
tool.invoke({"message": "test"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_sync_function_returning_coroutine() -> None:
|
||||||
|
"""Test handling of sync functions that return coroutines."""
|
||||||
|
|
||||||
|
async def inner_async(message: str) -> str:
|
||||||
|
return f"Inner async: {message}"
|
||||||
|
|
||||||
|
def sync_func_returning_coro(message: str) -> Any:
|
||||||
|
return inner_async(message)
|
||||||
|
|
||||||
|
tool = CrewStructuredTool.from_function(
|
||||||
|
func=sync_func_returning_coro,
|
||||||
|
name="sync_coro_tool",
|
||||||
|
description="A sync tool that returns coroutine",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = tool.invoke({"message": "test"})
|
||||||
|
assert result == "Inner async: test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_sync_function_returning_coroutine_in_event_loop_raises_error() -> (
|
||||||
|
None
|
||||||
|
):
|
||||||
|
"""Test that sync functions returning coroutines raise RuntimeError in event loop."""
|
||||||
|
|
||||||
|
async def inner_async(message: str) -> str:
|
||||||
|
return f"Inner async: {message}"
|
||||||
|
|
||||||
|
def sync_func_returning_coro(message: str) -> Any:
|
||||||
|
return inner_async(message)
|
||||||
|
|
||||||
|
tool = CrewStructuredTool.from_function(
|
||||||
|
func=sync_func_returning_coro,
|
||||||
|
name="sync_coro_tool",
|
||||||
|
description="A sync tool that returns coroutine",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
RuntimeError,
|
||||||
|
match="Sync function.*returned a coroutine but we're in an event loop",
|
||||||
|
):
|
||||||
|
tool.invoke({"message": "test"})
|
||||||
|
|||||||
Reference in New Issue
Block a user