Apply automatic linting fixes to tests directory

Co-Authored-By: Joe Moura <joao@crewai.com>
Author: Devin AI
Date: 2025-05-12 13:31:07 +00:00
parent ad1ea46bbb
commit 46621113af
62 changed files with 1738 additions and 1821 deletions
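
The hunks below follow a handful of mechanical patterns: explicit -> None return annotations on test functions and callbacks, trailing commas in multi-line calls and literals, set literals in place of set(...) concatenation, !s conversions instead of str(...) inside f-strings, builtin generics (list[X] | None) replacing typing.List/Optional, and pytest.raises replacing self.assertRaises. A minimal sketch of those rewrites follows; the commit does not name the linter (Ruff's autofix is a plausible source), and these functions are illustrative placeholders rather than repository code:

def describe(model: str) -> str:
    # str(model) interpolation becomes the !s conversion inside an f-string.
    return f"model={model!s}"


def build_stop(stop_words: list[str]) -> set[str]:
    # set(stop_words + ["\nObservation:"]) becomes a set literal with unpacking.
    return {*stop_words, "\nObservation:"}


def test_describe_and_build_stop() -> None:
    # Test functions gain an explicit -> None annotation, and multi-line
    # calls gain a trailing comma after the last argument.
    assert describe(
        "gpt-4o-mini",
    ) == "model=gpt-4o-mini"
    assert build_stop(["STOP", "END"]) == {"STOP", "END", "\nObservation:"}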

View File

@@ -22,7 +22,7 @@ from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.tool_usage_events import ToolUsageFinishedEvent
def test_agent_llm_creation_with_env_vars():
def test_agent_llm_creation_with_env_vars() -> None:
# Store original environment variables
original_api_key = os.environ.get("OPENAI_API_KEY")
original_api_base = os.environ.get("OPENAI_API_BASE")
@@ -65,7 +65,7 @@ def test_agent_llm_creation_with_env_vars():
os.environ["OPENAI_MODEL_NAME"] = original_model_name
def test_agent_creation():
def test_agent_creation() -> None:
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
assert agent.role == "test role"
@@ -73,7 +73,7 @@ def test_agent_creation():
assert agent.backstory == "test backstory"
def test_agent_with_only_system_template():
def test_agent_with_only_system_template() -> None:
"""Test that an agent with only system_template works without errors."""
agent = Agent(
role="Test Role",
@@ -89,7 +89,7 @@ def test_agent_with_only_system_template():
assert agent.backstory == "Test Backstory"
def test_agent_with_only_prompt_template():
def test_agent_with_only_prompt_template() -> None:
"""Test that an agent with only system_template works without errors."""
agent = Agent(
role="Test Role",
@@ -105,7 +105,7 @@ def test_agent_with_only_prompt_template():
assert agent.backstory == "Test Backstory"
def test_agent_with_missing_response_template():
def test_agent_with_missing_response_template() -> None:
"""Test that an agent with system_template and prompt_template but no response_template works without errors."""
agent = Agent(
role="Test Role",
@@ -122,20 +122,20 @@ def test_agent_with_missing_response_template():
assert agent.backstory == "Test Backstory"
def test_agent_default_values():
def test_agent_default_values() -> None:
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
assert agent.llm.model == "gpt-4o-mini"
assert agent.allow_delegation is False
def test_custom_llm():
def test_custom_llm() -> None:
agent = Agent(
role="test role", goal="test goal", backstory="test backstory", llm="gpt-4"
role="test role", goal="test goal", backstory="test backstory", llm="gpt-4",
)
assert agent.llm.model == "gpt-4"
def test_custom_llm_with_langchain():
def test_custom_llm_with_langchain() -> None:
from langchain_openai import ChatOpenAI
agent = Agent(
@@ -148,7 +148,7 @@ def test_custom_llm_with_langchain():
assert agent.llm.model == "gpt-4"
def test_custom_llm_temperature_preservation():
def test_custom_llm_temperature_preservation() -> None:
from langchain_openai import ChatOpenAI
langchain_llm = ChatOpenAI(temperature=0.7, model="gpt-4")
@@ -165,7 +165,7 @@ def test_custom_llm_temperature_preservation():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execution():
def test_agent_execution() -> None:
agent = Agent(
role="test role",
goal="test goal",
@@ -184,7 +184,7 @@ def test_agent_execution():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execution_with_tools():
def test_agent_execution_with_tools() -> None:
@tool
def multiplier(first_number: int, second_number: int) -> float:
"""Useful for when you need to multiply two numbers together."""
@@ -206,7 +206,7 @@ def test_agent_execution_with_tools():
received_events = []
@crewai_event_bus.on(ToolUsageFinishedEvent)
def handle_tool_end(source, event):
def handle_tool_end(source, event) -> None:
received_events.append(event)
output = agent.execute_task(task)
@@ -219,7 +219,7 @@ def test_agent_execution_with_tools():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_logging_tool_usage():
def test_logging_tool_usage() -> None:
@tool
def multiplier(first_number: int, second_number: int) -> float:
"""Useful for when you need to multiply two numbers together."""
@@ -244,7 +244,7 @@ def test_logging_tool_usage():
agent.tools_handler.cache = CacheHandler()
output = agent.execute_task(task)
tool_usage = InstructorToolCalling(
tool_name=multiplier.name, arguments={"first_number": 3, "second_number": 4}
tool_name=multiplier.name, arguments={"first_number": 3, "second_number": 4},
)
assert output == "The result of the multiplication is 12."
@@ -253,7 +253,7 @@ def test_logging_tool_usage():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_hitting():
def test_cache_hitting() -> None:
@tool
def multiplier(first_number: int, second_number: int) -> float:
"""Useful for when you need to multiply two numbers together."""
@@ -305,7 +305,7 @@ def test_cache_hitting():
received_events = []
@crewai_event_bus.on(ToolUsageFinishedEvent)
def handle_tool_end(source, event):
def handle_tool_end(source, event) -> None:
received_events.append(event)
with (
@@ -320,7 +320,7 @@ def test_cache_hitting():
output = agent.execute_task(task)
assert output == "0"
read.assert_called_with(
tool="multiplier", input={"first_number": 2, "second_number": 6}
tool="multiplier", input={"first_number": 2, "second_number": 6},
)
assert len(received_events) == 1
assert isinstance(received_events[0], ToolUsageFinishedEvent)
@@ -328,7 +328,7 @@ def test_cache_hitting():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_disabling_cache_for_agent():
def test_disabling_cache_for_agent() -> None:
@tool
def multiplier(first_number: int, second_number: int) -> float:
"""Useful for when you need to multiply two numbers together."""
@@ -392,7 +392,7 @@ def test_disabling_cache_for_agent():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execution_with_specific_tools():
def test_agent_execution_with_specific_tools() -> None:
@tool
def multiplier(first_number: int, second_number: int) -> float:
"""Useful for when you need to multiply two numbers together."""
@@ -415,7 +415,7 @@ def test_agent_execution_with_specific_tools():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool() -> None:
@tool
def multiplier(first_number: int, second_number: int) -> float:
"""Useful for when you need to multiply two numbers together."""
@@ -441,7 +441,7 @@ def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_powered_by_new_o_model_family_that_uses_tool():
def test_agent_powered_by_new_o_model_family_that_uses_tool() -> None:
@tool
def comapny_customer_data() -> float:
"""Useful for getting customer related data."""
@@ -467,11 +467,12 @@ def test_agent_powered_by_new_o_model_family_that_uses_tool():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_custom_max_iterations():
def test_agent_custom_max_iterations() -> None:
@tool
def get_final_answer() -> float:
"""Get the final answer but don't give it yet, just re-use this
tool non-stop."""
tool non-stop.
"""
return 42
agent = Agent(
@@ -483,7 +484,7 @@ def test_agent_custom_max_iterations():
)
with patch.object(
LLM, "call", wraps=LLM("gpt-4o", stop=["\nObservation:"]).call
LLM, "call", wraps=LLM("gpt-4o", stop=["\nObservation:"]).call,
) as private_mock:
task = Task(
description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.",
@@ -497,11 +498,12 @@ def test_agent_custom_max_iterations():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_repeated_tool_usage(capsys):
def test_agent_repeated_tool_usage(capsys) -> None:
@tool
def get_final_answer() -> float:
"""Get the final answer but don't give it yet, just re-use this
tool non-stop."""
tool non-stop.
"""
return 42
agent = Agent(
@@ -534,11 +536,12 @@ def test_agent_repeated_tool_usage(capsys):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys) -> None:
@tool
def get_final_answer(anything: str) -> float:
"""Get the final answer but don't give it yet, just re-use this
tool non-stop."""
tool non-stop.
"""
return 42
agent = Agent(
@@ -570,11 +573,12 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_moved_on_after_max_iterations():
def test_agent_moved_on_after_max_iterations() -> None:
@tool
def get_final_answer() -> float:
"""Get the final answer but don't give it yet, just re-use this
tool non-stop."""
tool non-stop.
"""
return 42
agent = Agent(
@@ -597,11 +601,12 @@ def test_agent_moved_on_after_max_iterations():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_respect_the_max_rpm_set(capsys):
def test_agent_respect_the_max_rpm_set(capsys) -> None:
@tool
def get_final_answer() -> float:
"""Get the final answer but don't give it yet, just re-use this
tool non-stop."""
tool non-stop.
"""
return 42
agent = Agent(
@@ -631,7 +636,7 @@ def test_agent_respect_the_max_rpm_set(capsys):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_respect_the_max_rpm_set_over_crew_rpm(capsys):
def test_agent_respect_the_max_rpm_set_over_crew_rpm(capsys) -> None:
from unittest.mock import patch
from crewai.tools import tool
@@ -639,7 +644,8 @@ def test_agent_respect_the_max_rpm_set_over_crew_rpm(capsys):
@tool
def get_final_answer() -> float:
"""Get the final answer but don't give it yet, just re-use this
tool non-stop."""
tool non-stop.
"""
return 42
agent = Agent(
@@ -669,7 +675,7 @@ def test_agent_respect_the_max_rpm_set_over_crew_rpm(capsys):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_without_max_rpm_respects_crew_rpm(capsys):
def test_agent_without_max_rpm_respects_crew_rpm(capsys) -> None:
from unittest.mock import patch
from crewai.tools import tool
@@ -729,7 +735,7 @@ def test_agent_without_max_rpm_respects_crew_rpm(capsys):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_error_on_parsing_tool(capsys):
def test_agent_error_on_parsing_tool(capsys) -> None:
from unittest.mock import patch
from crewai.tools import tool
@@ -737,7 +743,8 @@ def test_agent_error_on_parsing_tool(capsys):
@tool
def get_final_answer() -> float:
"""Get the final answer but don't give it yet, just re-use this
tool non-stop."""
tool non-stop.
"""
return 42
agent1 = Agent(
@@ -753,7 +760,7 @@ def test_agent_error_on_parsing_tool(capsys):
expected_output="The final answer",
agent=agent1,
tools=[get_final_answer],
)
),
]
crew = Crew(
@@ -772,7 +779,7 @@ def test_agent_error_on_parsing_tool(capsys):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_remembers_output_format_after_using_tools_too_many_times():
def test_agent_remembers_output_format_after_using_tools_too_many_times() -> None:
from unittest.mock import patch
from crewai.tools import tool
@@ -780,7 +787,8 @@ def test_agent_remembers_output_format_after_using_tools_too_many_times():
@tool
def get_final_answer() -> float:
"""Get the final answer but don't give it yet, just re-use this
tool non-stop."""
tool non-stop.
"""
return 42
agent1 = Agent(
@@ -796,7 +804,7 @@ def test_agent_remembers_output_format_after_using_tools_too_many_times():
expected_output="The final answer",
agent=agent1,
tools=[get_final_answer],
)
),
]
crew = Crew(agents=[agent1], tasks=tasks, verbose=True)
@@ -807,15 +815,15 @@ def test_agent_remembers_output_format_after_using_tools_too_many_times():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_use_specific_tasks_output_as_context(capsys):
def test_agent_use_specific_tasks_output_as_context(capsys) -> None:
agent1 = Agent(role="test role", goal="test goal", backstory="test backstory")
agent2 = Agent(role="test role2", goal="test goal2", backstory="test backstory2")
say_hi_task = Task(
description="Just say hi.", agent=agent1, expected_output="Your greeting."
description="Just say hi.", agent=agent1, expected_output="Your greeting.",
)
say_bye_task = Task(
description="Just say bye.", agent=agent1, expected_output="Your farewell."
description="Just say bye.", agent=agent1, expected_output="Your farewell.",
)
answer_task = Task(
description="Answer accordingly to the context you got.",
@@ -834,9 +842,9 @@ def test_agent_use_specific_tasks_output_as_context(capsys):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_step_callback():
def test_agent_step_callback() -> None:
class StepCallback:
def callback(self, step):
def callback(self, step) -> None:
pass
with patch.object(StepCallback, "callback") as callback:
@@ -868,7 +876,7 @@ def test_agent_step_callback():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_function_calling_llm():
def test_agent_function_calling_llm() -> None:
llm = "gpt-4o"
@tool
@@ -901,7 +909,7 @@ def test_agent_function_calling_llm():
with (
patch.object(
instructor, "from_litellm", wraps=instructor.from_litellm
instructor, "from_litellm", wraps=instructor.from_litellm,
) as mock_from_litellm,
patch.object(
ToolUsage,
@@ -915,7 +923,7 @@ def test_agent_function_calling_llm():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_tool_result_as_answer_is_the_final_answer_for_the_agent():
def test_tool_result_as_answer_is_the_final_answer_for_the_agent() -> None:
from crewai.tools import BaseTool
class MyCustomTool(BaseTool):
@@ -945,7 +953,7 @@ def test_tool_result_as_answer_is_the_final_answer_for_the_agent():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_tool_usage_information_is_appended_to_agent():
def test_tool_usage_information_is_appended_to_agent() -> None:
from crewai.tools import BaseTool
class MyCustomTool(BaseTool):
@@ -977,11 +985,11 @@ def test_tool_usage_information_is_appended_to_agent():
"tool_name": "Decide Greetings",
"tool_args": {},
"result_as_answer": True,
}
},
]
def test_agent_definition_based_on_dict():
def test_agent_definition_based_on_dict() -> None:
config = {
"role": "test role",
"goal": "test goal",
@@ -1000,7 +1008,7 @@ def test_agent_definition_based_on_dict():
# test for human input
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_human_input():
def test_agent_human_input() -> None:
# Agent configuration
config = {
"role": "test role",
@@ -1023,7 +1031,7 @@ def test_agent_human_input():
[
"Don't say hi, say Hello instead!", # First feedback: instruct change
"", # Second feedback: empty string signals acceptance
]
],
)
def ask_human_input_side_effect(*args, **kwargs):
@@ -1040,7 +1048,7 @@ def test_agent_human_input():
CrewAgentExecutor,
"_invoke_loop",
return_value=AgentFinish(output="Hello", thought="", text=""),
) as mock_invoke_loop,
),
):
# Execute the task
output = agent.execute_task(task)
@@ -1052,7 +1060,7 @@ def test_agent_human_input():
assert output.strip().lower() == "hello"
def test_interpolate_inputs():
def test_interpolate_inputs() -> None:
agent = Agent(
role="{topic} specialist",
goal="Figure {goal} out",
@@ -1070,7 +1078,7 @@ def test_interpolate_inputs():
assert agent.backstory == "I am the master of nothing"
def test_not_using_system_prompt():
def test_not_using_system_prompt() -> None:
agent = Agent(
role="{topic} specialist",
goal="Figure {goal} out",
@@ -1083,7 +1091,7 @@ def test_not_using_system_prompt():
assert not agent.agent_executor.prompt.get("system")
def test_using_system_prompt():
def test_using_system_prompt() -> None:
agent = Agent(
role="{topic} specialist",
goal="Figure {goal} out",
@@ -1095,7 +1103,7 @@ def test_using_system_prompt():
assert agent.agent_executor.prompt.get("system")
def test_system_and_prompt_template():
def test_system_and_prompt_template() -> None:
agent = Agent(
role="{topic} specialist",
goal="Figure {goal} out",
@@ -1148,7 +1156,7 @@ Thought:<|eot_id|>
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_training_handler(crew_training_handler):
def test_agent_training_handler(crew_training_handler) -> None:
task_prompt = "What is 1 + 1?"
agent = Agent(
role="test role",
@@ -1157,7 +1165,7 @@ def test_agent_training_handler(crew_training_handler):
verbose=True,
)
crew_training_handler().load.return_value = {
f"{str(agent.id)}": {"0": {"human_feedback": "good"}}
f"{agent.id!s}": {"0": {"human_feedback": "good"}},
}
result = agent._training_handler(task_prompt=task_prompt)
@@ -1165,12 +1173,12 @@ def test_agent_training_handler(crew_training_handler):
assert result == "What is 1 + 1?\n\nYou MUST follow these instructions: \n good"
crew_training_handler.assert_has_calls(
[mock.call(), mock.call("training_data.pkl"), mock.call().load()]
[mock.call(), mock.call("training_data.pkl"), mock.call().load()],
)
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_use_trained_data(crew_training_handler):
def test_agent_use_trained_data(crew_training_handler) -> None:
task_prompt = "What is 1 + 1?"
agent = Agent(
role="researcher",
@@ -1183,8 +1191,8 @@ def test_agent_use_trained_data(crew_training_handler):
"suggestions": [
"The result of the math operation must be right.",
"Result must be better than 1.",
]
}
],
},
}
result = agent._use_trained_data(task_prompt=task_prompt)
@@ -1194,11 +1202,11 @@ def test_agent_use_trained_data(crew_training_handler):
" - The result of the math operation must be right.\n - Result must be better than 1."
)
crew_training_handler.assert_has_calls(
[mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()]
[mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()],
)
def test_agent_max_retry_limit():
def test_agent_max_retry_limit() -> None:
agent = Agent(
role="test role",
goal="test goal",
@@ -1215,7 +1223,7 @@ def test_agent_max_retry_limit():
error_message = "Error happening while sending prompt to model."
with patch.object(
CrewAgentExecutor, "invoke", wraps=agent.agent_executor.invoke
CrewAgentExecutor, "invoke", wraps=agent.agent_executor.invoke,
) as invoke_mock:
invoke_mock.side_effect = Exception(error_message)
@@ -1237,7 +1245,7 @@ def test_agent_max_retry_limit():
"tool_names": "",
"tools": "",
"ask_for_human_input": True,
}
},
),
mock.call(
{
@@ -1245,13 +1253,13 @@ def test_agent_max_retry_limit():
"tool_names": "",
"tools": "",
"ask_for_human_input": True,
}
},
),
]
],
)
def test_agent_with_llm():
def test_agent_with_llm() -> None:
agent = Agent(
role="test role",
goal="test goal",
@@ -1264,7 +1272,7 @@ def test_agent_with_llm():
assert agent.llm.temperature == 0.7
def test_agent_with_custom_stop_words():
def test_agent_with_custom_stop_words() -> None:
stop_words = ["STOP", "END"]
agent = Agent(
role="test role",
@@ -1274,13 +1282,13 @@ def test_agent_with_custom_stop_words():
)
assert isinstance(agent.llm, LLM)
assert set(agent.llm.stop) == set(stop_words + ["\nObservation:"])
assert set(agent.llm.stop) == {*stop_words, "\nObservation:"}
assert all(word in agent.llm.stop for word in stop_words)
assert "\nObservation:" in agent.llm.stop
def test_agent_with_callbacks():
def dummy_callback(response):
def test_agent_with_callbacks() -> None:
def dummy_callback(response) -> None:
pass
agent = Agent(
@@ -1295,7 +1303,7 @@ def test_agent_with_callbacks():
assert agent.llm.callbacks[0] == dummy_callback
def test_agent_with_additional_kwargs():
def test_agent_with_additional_kwargs() -> None:
agent = Agent(
role="test role",
goal="test goal",
@@ -1318,7 +1326,7 @@ def test_agent_with_additional_kwargs():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call():
def test_llm_call() -> None:
llm = LLM(model="gpt-3.5-turbo")
messages = [{"role": "user", "content": "Say 'Hello, World!'"}]
@@ -1327,7 +1335,7 @@ def test_llm_call():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_error():
def test_llm_call_with_error() -> None:
llm = LLM(model="non-existent-model")
messages = [{"role": "user", "content": "This should fail"}]
@@ -1336,7 +1344,7 @@ def test_llm_call_with_error():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_context_length_exceeds_limit():
def test_handle_context_length_exceeds_limit() -> None:
# Import necessary modules
from crewai.utilities.agent_utils import handle_context_length
from crewai.utilities.i18n import I18N
@@ -1361,7 +1369,7 @@ def test_handle_context_length_exceeds_limit():
{
"role": "user",
"content": "This is a test message that would exceed context length",
}
},
]
# Set up test parameters
@@ -1389,7 +1397,7 @@ def test_handle_context_length_exceeds_limit():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_context_length_exceeds_limit_cli_no():
def test_handle_context_length_exceeds_limit_cli_no() -> None:
agent = Agent(
role="test role",
goal="test goal",
@@ -1399,7 +1407,7 @@ def test_handle_context_length_exceeds_limit_cli_no():
task = Task(description="test task", agent=agent, expected_output="test output")
with patch.object(
CrewAgentExecutor, "invoke", wraps=agent.agent_executor.invoke
CrewAgentExecutor, "invoke", wraps=agent.agent_executor.invoke,
) as private_mock:
task = Task(
description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.",
@@ -1411,12 +1419,12 @@ def test_handle_context_length_exceeds_limit_cli_no():
private_mock.assert_called_once()
pytest.raises(SystemExit)
with patch(
"crewai.utilities.agent_utils.handle_context_length"
"crewai.utilities.agent_utils.handle_context_length",
) as mock_handle_context:
mock_handle_context.assert_not_called()
def test_agent_with_all_llm_attributes():
def test_agent_with_all_llm_attributes() -> None:
agent = Agent(
role="test role",
goal="test goal",
@@ -1448,7 +1456,7 @@ def test_agent_with_all_llm_attributes():
assert agent.llm.temperature == 0.7
assert agent.llm.top_p == 0.9
assert agent.llm.n == 1
assert set(agent.llm.stop) == set(["STOP", "END", "\nObservation:"])
assert set(agent.llm.stop) == {"STOP", "END", "\nObservation:"}
assert all(word in agent.llm.stop for word in ["STOP", "END", "\nObservation:"])
assert agent.llm.max_tokens == 100
assert agent.llm.presence_penalty == 0.1
@@ -1464,7 +1472,7 @@ def test_agent_with_all_llm_attributes():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_all_attributes():
def test_llm_call_with_all_attributes() -> None:
llm = LLM(
model="gpt-3.5-turbo",
temperature=0.7,
@@ -1481,7 +1489,7 @@ def test_llm_call_with_all_attributes():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_ollama_llama3():
def test_agent_with_ollama_llama3() -> None:
agent = Agent(
role="test role",
goal="test goal",
@@ -1502,7 +1510,7 @@ def test_agent_with_ollama_llama3():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_ollama_llama3():
def test_llm_call_with_ollama_llama3() -> None:
llm = LLM(
model="ollama/llama3.2:3b",
base_url="http://localhost:11434",
@@ -1510,7 +1518,7 @@ def test_llm_call_with_ollama_llama3():
max_tokens=30,
)
messages = [
{"role": "user", "content": "Respond in 20 words. Which model are you?"}
{"role": "user", "content": "Respond in 20 words. Which model are you?"},
]
response = llm.call(messages)
@@ -1521,7 +1529,7 @@ def test_llm_call_with_ollama_llama3():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execute_task_basic():
def test_agent_execute_task_basic() -> None:
agent = Agent(
role="test role",
goal="test goal",
@@ -1540,7 +1548,7 @@ def test_agent_execute_task_basic():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execute_task_with_context():
def test_agent_execute_task_with_context() -> None:
agent = Agent(
role="test role",
goal="test goal",
@@ -1558,11 +1566,12 @@ def test_agent_execute_task_with_context():
result = agent.execute_task(task, context=context)
assert len(result.split(".")) == 3
assert "fox" in result.lower() and "dog" in result.lower()
assert "fox" in result.lower()
assert "dog" in result.lower()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execute_task_with_tool():
def test_agent_execute_task_with_tool() -> None:
@tool
def dummy_tool(query: str) -> str:
"""Useful for when you need to get a dummy result for a query."""
@@ -1587,7 +1596,7 @@ def test_agent_execute_task_with_tool():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execute_task_with_custom_llm():
def test_agent_execute_task_with_custom_llm() -> None:
agent = Agent(
role="test role",
goal="test goal",
@@ -1603,12 +1612,12 @@ def test_agent_execute_task_with_custom_llm():
result = agent.execute_task(task)
assert result.startswith(
"Artificial minds,\nCoding thoughts in circuits bright,\nAI's silent might."
"Artificial minds,\nCoding thoughts in circuits bright,\nAI's silent might.",
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execute_task_with_ollama():
def test_agent_execute_task_with_ollama() -> None:
agent = Agent(
role="test role",
goal="test goal",
@@ -1628,7 +1637,7 @@ def test_agent_execute_task_with_ollama():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_knowledge_sources():
def test_agent_with_knowledge_sources() -> None:
# Create a knowledge source with some content
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -1660,12 +1669,12 @@ def test_agent_with_knowledge_sources():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold() -> None:
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig(results_limit=10, score_threshold=0.5)
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage",
) as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
mock_knowledge_instance.sources = [string_source]
@@ -1695,12 +1704,12 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_default():
def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_default() -> None:
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig()
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage",
) as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
mock_knowledge_instance.sources = [string_source]
@@ -1732,7 +1741,7 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_defau
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_knowledge_sources_extensive_role():
def test_agent_with_knowledge_sources_extensive_role() -> None:
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -1762,7 +1771,7 @@ def test_agent_with_knowledge_sources_extensive_role():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_knowledge_sources_works_with_copy():
def test_agent_with_knowledge_sources_works_with_copy() -> None:
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -1783,7 +1792,7 @@ def test_agent_with_knowledge_sources_works_with_copy():
)
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage",
) as MockKnowledgeStorage:
mock_knowledge_storage = MockKnowledgeStorage.return_value
agent.knowledge_storage = mock_knowledge_storage
@@ -1801,7 +1810,7 @@ def test_agent_with_knowledge_sources_works_with_copy():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_knowledge_sources_generate_search_query():
def test_agent_with_knowledge_sources_generate_search_query() -> None:
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -1835,7 +1844,7 @@ def test_agent_with_knowledge_sources_generate_search_query():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_litellm_auth_error_handling():
def test_litellm_auth_error_handling() -> None:
"""Test that LiteLLM authentication errors are handled correctly and not retried."""
from litellm import AuthenticationError as LiteLLMAuthenticationError
@@ -1861,7 +1870,7 @@ def test_litellm_auth_error_handling():
pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"),
):
mock_llm_call.side_effect = LiteLLMAuthenticationError(
message="Invalid API key", llm_provider="openai", model="gpt-4"
message="Invalid API key", llm_provider="openai", model="gpt-4",
)
agent.execute_task(task)
@@ -1869,7 +1878,7 @@ def test_litellm_auth_error_handling():
mock_llm_call.assert_called_once()
def test_crew_agent_executor_litellm_auth_error():
def test_crew_agent_executor_litellm_auth_error() -> None:
"""Test that CrewAgentExecutor handles LiteLLM authentication errors by raising them."""
from litellm.exceptions import AuthenticationError
@@ -1911,18 +1920,18 @@ def test_crew_agent_executor_litellm_auth_error():
pytest.raises(AuthenticationError) as exc_info,
):
mock_llm_call.side_effect = AuthenticationError(
message="Invalid API key", llm_provider="openai", model="gpt-4"
message="Invalid API key", llm_provider="openai", model="gpt-4",
)
executor.invoke(
{
"input": "test input",
"tool_names": "",
"tools": "",
}
},
)
# Verify error handling messages
error_message = f"Error during LLM call: {str(mock_llm_call.side_effect)}"
error_message = f"Error during LLM call: {mock_llm_call.side_effect!s}"
mock_printer.assert_any_call(
content=error_message,
color="red",
@@ -1938,7 +1947,7 @@ def test_crew_agent_executor_litellm_auth_error():
assert exc_info.value.model == "gpt-4"
def test_litellm_anthropic_error_handling():
def test_litellm_anthropic_error_handling() -> None:
"""Test that AnthropicError from LiteLLM is handled correctly and not retried."""
from litellm.llms.anthropic.common_utils import AnthropicError
@@ -1974,7 +1983,7 @@ def test_litellm_anthropic_error_handling():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_get_knowledge_search_query():
def test_get_knowledge_search_query() -> None:
"""Test that _get_knowledge_search_query calls the LLM with the correct prompts."""
from crewai.utilities.i18n import I18N
@@ -2014,14 +2023,14 @@ def test_get_knowledge_search_query():
{
"role": "system",
"content": i18n.slice(
"knowledge_search_query_system_prompt"
"knowledge_search_query_system_prompt",
).format(task_prompt=task.description),
},
{
"role": "user",
"content": i18n.slice("knowledge_search_query").format(
task_prompt=task_prompt
task_prompt=task_prompt,
),
},
]
],
)

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any
import pytest
from pydantic import BaseModel
@@ -12,7 +12,7 @@ from crewai.utilities.token_counter_callback import TokenProcess
# Concrete implementation for testing
class ConcreteAgentAdapter(BaseAgentAdapter):
def configure_tools(
self, tools: Optional[List[BaseTool]] = None, **kwargs: Any
self, tools: list[BaseTool] | None = None, **kwargs: Any,
) -> None:
# Simple implementation for testing
self.tools = tools or []
@@ -20,35 +20,35 @@ class ConcreteAgentAdapter(BaseAgentAdapter):
def execute_task(
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[Any]] = None,
context: str | None = None,
tools: list[Any] | None = None,
) -> str:
# Dummy implementation needed due to BaseAgent inheritance
return "Task executed"
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> Any:
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> Any:
# Dummy implementation
return None
def get_delegation_tools(
self, tools: List[BaseTool], tool_map: Optional[Dict[str, BaseTool]]
) -> List[BaseTool]:
self, tools: list[BaseTool], tool_map: dict[str, BaseTool] | None,
) -> list[BaseTool]:
# Dummy implementation
return []
def _parse_output(self, agent_output: Any, token_process: TokenProcess):
def _parse_output(self, agent_output: Any, token_process: TokenProcess) -> None:
# Dummy implementation
pass
def get_output_converter(self, tools: Optional[List[BaseTool]] = None) -> Any:
def get_output_converter(self, tools: list[BaseTool] | None = None) -> Any:
# Dummy implementation
return None
def test_base_agent_adapter_initialization():
def test_base_agent_adapter_initialization() -> None:
"""Test initialization of the concrete agent adapter."""
adapter = ConcreteAgentAdapter(
role="test role", goal="test goal", backstory="test backstory"
role="test role", goal="test goal", backstory="test backstory",
)
assert isinstance(adapter, BaseAgent)
assert isinstance(adapter, BaseAgentAdapter)
@@ -57,7 +57,7 @@ def test_base_agent_adapter_initialization():
assert adapter.adapted_structured_output is False
def test_base_agent_adapter_initialization_with_config():
def test_base_agent_adapter_initialization_with_config() -> None:
"""Test initialization with agent_config."""
config = {"model": "gpt-4"}
adapter = ConcreteAgentAdapter(
@@ -69,10 +69,10 @@ def test_base_agent_adapter_initialization_with_config():
assert adapter._agent_config == config
def test_configure_tools_method_exists():
def test_configure_tools_method_exists() -> None:
"""Test that configure_tools method exists and can be called."""
adapter = ConcreteAgentAdapter(
role="test role", goal="test goal", backstory="test backstory"
role="test role", goal="test goal", backstory="test backstory",
)
# Create dummy tools if needed, or pass None
tools = []
@@ -81,10 +81,10 @@ def test_configure_tools_method_exists():
assert adapter.tools == tools
def test_configure_structured_output_method_exists():
def test_configure_structured_output_method_exists() -> None:
"""Test that configure_structured_output method exists and can be called."""
adapter = ConcreteAgentAdapter(
role="test role", goal="test goal", backstory="test backstory"
role="test role", goal="test goal", backstory="test backstory",
)
# Define a dummy structure or pass None/Any
@@ -95,10 +95,9 @@ def test_configure_structured_output_method_exists():
adapter.configure_structured_output(structured_output)
# Add assertions here if configure_structured_output modifies state
# For now, just ensuring it runs without error is sufficient
pass
def test_base_agent_adapter_inherits_base_agent():
def test_base_agent_adapter_inherits_base_agent() -> None:
"""Test that BaseAgentAdapter inherits from BaseAgent."""
assert issubclass(BaseAgentAdapter, BaseAgent)
@@ -107,7 +106,7 @@ class ConcreteAgentAdapterWithoutRequiredMethods(BaseAgentAdapter):
pass
def test_base_agent_adapter_fails_without_required_methods():
def test_base_agent_adapter_fails_without_required_methods() -> None:
"""Test that BaseAgentAdapter fails without required methods."""
with pytest.raises(TypeError):
ConcreteAgentAdapterWithoutRequiredMethods() # type: ignore

View File

@@ -1,4 +1,3 @@
from typing import Any, List
from unittest.mock import Mock
import pytest
@@ -8,7 +7,7 @@ from crewai.tools.base_tool import BaseTool
class ConcreteToolAdapter(BaseToolAdapter):
def configure_tools(self, tools: List[BaseTool]) -> None:
def configure_tools(self, tools: list[BaseTool]) -> None:
self.converted_tools = [f"converted_{tool.name}" for tool in tools]
@@ -31,19 +30,19 @@ def tools_list(mock_tool_1, mock_tool_2):
return [mock_tool_1, mock_tool_2]
def test_initialization_with_tools(tools_list):
def test_initialization_with_tools(tools_list) -> None:
adapter = ConcreteToolAdapter(tools=tools_list)
assert adapter.original_tools == tools_list
assert adapter.converted_tools == [] # Conversion happens in configure_tools
def test_initialization_without_tools():
def test_initialization_without_tools() -> None:
adapter = ConcreteToolAdapter()
assert adapter.original_tools == []
assert adapter.converted_tools == []
def test_configure_tools(tools_list):
def test_configure_tools(tools_list) -> None:
adapter = ConcreteToolAdapter()
adapter.configure_tools(tools_list)
assert adapter.converted_tools == ["converted_Mock Tool 1", "converted_MockTool2"]
@@ -58,28 +57,28 @@ def test_configure_tools(tools_list):
assert adapter_with_init_tools.original_tools == tools_list
def test_tools_method(tools_list):
def test_tools_method(tools_list) -> None:
adapter = ConcreteToolAdapter()
adapter.configure_tools(tools_list)
assert adapter.tools() == ["converted_Mock Tool 1", "converted_MockTool2"]
def test_tools_method_empty():
def test_tools_method_empty() -> None:
adapter = ConcreteToolAdapter()
assert adapter.tools() == []
def test_sanitize_tool_name_with_spaces():
def test_sanitize_tool_name_with_spaces() -> None:
adapter = ConcreteToolAdapter()
assert adapter.sanitize_tool_name("Tool With Spaces") == "Tool_With_Spaces"
def test_sanitize_tool_name_without_spaces():
def test_sanitize_tool_name_without_spaces() -> None:
adapter = ConcreteToolAdapter()
assert adapter.sanitize_tool_name("ToolWithoutSpaces") == "ToolWithoutSpaces"
def test_sanitize_tool_name_empty():
def test_sanitize_tool_name_empty() -> None:
adapter = ConcreteToolAdapter()
assert adapter.sanitize_tool_name("") == ""
@@ -88,7 +87,7 @@ class ConcreteToolAdapterWithoutRequiredMethods(BaseToolAdapter):
pass
def test_tool_adapted_fails_without_required_methods():
def test_tool_adapted_fails_without_required_methods() -> None:
"""Test that BaseToolAdapter fails without required methods."""
with pytest.raises(TypeError):
ConcreteToolAdapterWithoutRequiredMethods() # type: ignore

View File

@@ -1,5 +1,5 @@
import hashlib
from typing import Any, List, Optional
from typing import Any
from pydantic import BaseModel
@@ -11,25 +11,25 @@ class MockAgent(BaseAgent):
def execute_task(
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
return ""
def create_agent_executor(self, tools=None) -> None: ...
def get_delegation_tools(self, agents: List["BaseAgent"]): ...
def get_delegation_tools(self, agents: list["BaseAgent"]) -> None: ...
def get_output_converter(
self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
): ...
self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str,
) -> None: ...
def test_key():
def test_key() -> None:
agent = MockAgent(
role="test role",
goal="test goal",
backstory="test backstory",
)
hash = hashlib.md5("test role|test goal|test backstory".encode()).hexdigest()
hash = hashlib.md5(b"test role|test goal|test backstory").hexdigest()
assert agent.key == hash

View File

@@ -11,11 +11,10 @@ from crewai.agents.parser import CrewAgentParser
@pytest.fixture
def parser():
agent = MockAgent()
p = CrewAgentParser(agent)
return p
return CrewAgentParser(agent)
def test_valid_action_parsing_special_characters(parser):
def test_valid_action_parsing_special_characters(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search\nAction Input: what's the temperature in SF?"
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -23,7 +22,7 @@ def test_valid_action_parsing_special_characters(parser):
assert result.tool_input == "what's the temperature in SF?"
def test_valid_action_parsing_with_json_tool_input(parser):
def test_valid_action_parsing_with_json_tool_input(parser) -> None:
text = """
Thought: Let's find the information
Action: query
@@ -36,7 +35,7 @@ def test_valid_action_parsing_with_json_tool_input(parser):
assert result.tool_input == expected_tool_input
def test_valid_action_parsing_with_quotes(parser):
def test_valid_action_parsing_with_quotes(parser) -> None:
text = 'Thought: Let\'s find the temperature\nAction: search\nAction Input: "temperature in SF"'
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -44,7 +43,7 @@ def test_valid_action_parsing_with_quotes(parser):
assert result.tool_input == "temperature in SF"
def test_valid_action_parsing_with_curly_braces(parser):
def test_valid_action_parsing_with_curly_braces(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search\nAction Input: {temperature in SF}"
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -52,7 +51,7 @@ def test_valid_action_parsing_with_curly_braces(parser):
assert result.tool_input == "{temperature in SF}"
def test_valid_action_parsing_with_angle_brackets(parser):
def test_valid_action_parsing_with_angle_brackets(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search\nAction Input: <temperature in SF>"
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -60,7 +59,7 @@ def test_valid_action_parsing_with_angle_brackets(parser):
assert result.tool_input == "<temperature in SF>"
def test_valid_action_parsing_with_parentheses(parser):
def test_valid_action_parsing_with_parentheses(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search\nAction Input: (temperature in SF)"
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -68,7 +67,7 @@ def test_valid_action_parsing_with_parentheses(parser):
assert result.tool_input == "(temperature in SF)"
def test_valid_action_parsing_with_mixed_brackets(parser):
def test_valid_action_parsing_with_mixed_brackets(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search\nAction Input: [temperature in {SF}]"
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -76,7 +75,7 @@ def test_valid_action_parsing_with_mixed_brackets(parser):
assert result.tool_input == "[temperature in {SF}]"
def test_valid_action_parsing_with_nested_quotes(parser):
def test_valid_action_parsing_with_nested_quotes(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search\nAction Input: \"what's the temperature in 'SF'?\""
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -84,7 +83,7 @@ def test_valid_action_parsing_with_nested_quotes(parser):
assert result.tool_input == "what's the temperature in 'SF'?"
def test_valid_action_parsing_with_incomplete_json(parser):
def test_valid_action_parsing_with_incomplete_json(parser) -> None:
text = 'Thought: Let\'s find the temperature\nAction: search\nAction Input: {"query": "temperature in SF"'
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -92,7 +91,7 @@ def test_valid_action_parsing_with_incomplete_json(parser):
assert result.tool_input == '{"query": "temperature in SF"}'
def test_valid_action_parsing_with_special_characters(parser):
def test_valid_action_parsing_with_special_characters(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search\nAction Input: what is the temperature in SF? @$%^&*"
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -100,7 +99,7 @@ def test_valid_action_parsing_with_special_characters(parser):
assert result.tool_input == "what is the temperature in SF? @$%^&*"
def test_valid_action_parsing_with_combination(parser):
def test_valid_action_parsing_with_combination(parser) -> None:
text = 'Thought: Let\'s find the temperature\nAction: search\nAction Input: "[what is the temperature in SF?]"'
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -108,7 +107,7 @@ def test_valid_action_parsing_with_combination(parser):
assert result.tool_input == "[what is the temperature in SF?]"
def test_valid_action_parsing_with_mixed_quotes(parser):
def test_valid_action_parsing_with_mixed_quotes(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search\nAction Input: \"what's the temperature in SF?\""
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -116,7 +115,7 @@ def test_valid_action_parsing_with_mixed_quotes(parser):
assert result.tool_input == "what's the temperature in SF?"
def test_valid_action_parsing_with_newlines(parser):
def test_valid_action_parsing_with_newlines(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search\nAction Input: what is\nthe temperature in SF?"
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -124,7 +123,7 @@ def test_valid_action_parsing_with_newlines(parser):
assert result.tool_input == "what is\nthe temperature in SF?"
def test_valid_action_parsing_with_escaped_characters(parser):
def test_valid_action_parsing_with_escaped_characters(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search\nAction Input: what is the temperature in SF? \\n"
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -132,7 +131,7 @@ def test_valid_action_parsing_with_escaped_characters(parser):
assert result.tool_input == "what is the temperature in SF? \\n"
def test_valid_action_parsing_with_json_string(parser):
def test_valid_action_parsing_with_json_string(parser) -> None:
text = 'Thought: Let\'s find the temperature\nAction: search\nAction Input: {"query": "temperature in SF"}'
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -140,7 +139,7 @@ def test_valid_action_parsing_with_json_string(parser):
assert result.tool_input == '{"query": "temperature in SF"}'
def test_valid_action_parsing_with_unbalanced_quotes(parser):
def test_valid_action_parsing_with_unbalanced_quotes(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search\nAction Input: \"what is the temperature in SF?"
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -148,61 +147,61 @@ def test_valid_action_parsing_with_unbalanced_quotes(parser):
assert result.tool_input == "what is the temperature in SF?"
def test_clean_action_no_formatting(parser):
def test_clean_action_no_formatting(parser) -> None:
action = "Ask question to senior researcher"
cleaned_action = parser._clean_action(action)
assert cleaned_action == "Ask question to senior researcher"
def test_clean_action_with_leading_asterisks(parser):
def test_clean_action_with_leading_asterisks(parser) -> None:
action = "** Ask question to senior researcher"
cleaned_action = parser._clean_action(action)
assert cleaned_action == "Ask question to senior researcher"
def test_clean_action_with_trailing_asterisks(parser):
def test_clean_action_with_trailing_asterisks(parser) -> None:
action = "Ask question to senior researcher **"
cleaned_action = parser._clean_action(action)
assert cleaned_action == "Ask question to senior researcher"
def test_clean_action_with_leading_and_trailing_asterisks(parser):
def test_clean_action_with_leading_and_trailing_asterisks(parser) -> None:
action = "** Ask question to senior researcher **"
cleaned_action = parser._clean_action(action)
assert cleaned_action == "Ask question to senior researcher"
def test_clean_action_with_multiple_leading_asterisks(parser):
def test_clean_action_with_multiple_leading_asterisks(parser) -> None:
action = "**** Ask question to senior researcher"
cleaned_action = parser._clean_action(action)
assert cleaned_action == "Ask question to senior researcher"
def test_clean_action_with_multiple_trailing_asterisks(parser):
def test_clean_action_with_multiple_trailing_asterisks(parser) -> None:
action = "Ask question to senior researcher ****"
cleaned_action = parser._clean_action(action)
assert cleaned_action == "Ask question to senior researcher"
def test_clean_action_with_spaces_and_asterisks(parser):
def test_clean_action_with_spaces_and_asterisks(parser) -> None:
action = " ** Ask question to senior researcher ** "
cleaned_action = parser._clean_action(action)
assert cleaned_action == "Ask question to senior researcher"
def test_clean_action_with_only_asterisks(parser):
def test_clean_action_with_only_asterisks(parser) -> None:
action = "****"
cleaned_action = parser._clean_action(action)
assert cleaned_action == ""
def test_clean_action_with_empty_string(parser):
def test_clean_action_with_empty_string(parser) -> None:
action = ""
cleaned_action = parser._clean_action(action)
assert cleaned_action == ""
def test_valid_final_answer_parsing(parser):
def test_valid_final_answer_parsing(parser) -> None:
text = (
"Thought: I found the information\nFinal Answer: The temperature is 100 degrees"
)
@@ -211,36 +210,36 @@ def test_valid_final_answer_parsing(parser):
assert result.output == "The temperature is 100 degrees"
def test_missing_action_error(parser):
def test_missing_action_error(parser) -> None:
text = "Thought: Let's find the temperature\nAction Input: what is the temperature in SF?"
with pytest.raises(OutputParserException) as exc_info:
parser.parse(text)
assert "Invalid Format: I missed the 'Action:' after 'Thought:'." in str(
exc_info.value
exc_info.value,
)
def test_missing_action_input_error(parser):
def test_missing_action_input_error(parser) -> None:
text = "Thought: Let's find the temperature\nAction: search"
with pytest.raises(OutputParserException) as exc_info:
parser.parse(text)
assert "I missed the 'Action Input:' after 'Action:'." in str(exc_info.value)
def test_safe_repair_json(parser):
def test_safe_repair_json(parser) -> None:
invalid_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": Senior Researcher'
expected_repaired_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"}'
result = parser._safe_repair_json(invalid_json)
assert result == expected_repaired_json
def test_safe_repair_json_unrepairable(parser):
def test_safe_repair_json_unrepairable(parser) -> None:
invalid_json = "{invalid_json"
result = parser._safe_repair_json(invalid_json)
assert result == invalid_json # Should return the original if unrepairable
def test_safe_repair_json_missing_quotes(parser):
def test_safe_repair_json_missing_quotes(parser) -> None:
invalid_json = (
'{task: "Research XAI", context: "Explainable AI", coworker: Senior Researcher}'
)
@@ -249,77 +248,77 @@ def test_safe_repair_json_missing_quotes(parser):
assert result == expected_repaired_json
def test_safe_repair_json_unclosed_brackets(parser):
def test_safe_repair_json_unclosed_brackets(parser) -> None:
invalid_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"'
expected_repaired_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"}'
result = parser._safe_repair_json(invalid_json)
assert result == expected_repaired_json
def test_safe_repair_json_extra_commas(parser):
def test_safe_repair_json_extra_commas(parser) -> None:
invalid_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher",}'
expected_repaired_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"}'
result = parser._safe_repair_json(invalid_json)
assert result == expected_repaired_json
def test_safe_repair_json_trailing_commas(parser):
def test_safe_repair_json_trailing_commas(parser) -> None:
invalid_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher",}'
expected_repaired_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"}'
result = parser._safe_repair_json(invalid_json)
assert result == expected_repaired_json
def test_safe_repair_json_single_quotes(parser):
def test_safe_repair_json_single_quotes(parser) -> None:
invalid_json = "{'task': 'Research XAI', 'context': 'Explainable AI', 'coworker': 'Senior Researcher'}"
expected_repaired_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"}'
result = parser._safe_repair_json(invalid_json)
assert result == expected_repaired_json
def test_safe_repair_json_mixed_quotes(parser):
def test_safe_repair_json_mixed_quotes(parser) -> None:
invalid_json = "{'task': \"Research XAI\", 'context': \"Explainable AI\", 'coworker': 'Senior Researcher'}"
expected_repaired_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"}'
result = parser._safe_repair_json(invalid_json)
assert result == expected_repaired_json
def test_safe_repair_json_unescaped_characters(parser):
def test_safe_repair_json_unescaped_characters(parser) -> None:
invalid_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher\n"}'
expected_repaired_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"}'
result = parser._safe_repair_json(invalid_json)
assert result == expected_repaired_json
def test_safe_repair_json_missing_colon(parser):
def test_safe_repair_json_missing_colon(parser) -> None:
invalid_json = '{"task" "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"}'
expected_repaired_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"}'
result = parser._safe_repair_json(invalid_json)
assert result == expected_repaired_json
def test_safe_repair_json_missing_comma(parser):
def test_safe_repair_json_missing_comma(parser) -> None:
invalid_json = '{"task": "Research XAI" "context": "Explainable AI", "coworker": "Senior Researcher"}'
expected_repaired_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"}'
result = parser._safe_repair_json(invalid_json)
assert result == expected_repaired_json
def test_safe_repair_json_unexpected_trailing_characters(parser):
def test_safe_repair_json_unexpected_trailing_characters(parser) -> None:
invalid_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"} random text'
expected_repaired_json = '{"task": "Research XAI", "context": "Explainable AI", "coworker": "Senior Researcher"}'
result = parser._safe_repair_json(invalid_json)
assert result == expected_repaired_json
def test_safe_repair_json_special_characters_key(parser):
def test_safe_repair_json_special_characters_key(parser) -> None:
invalid_json = '{"task!@#": "Research XAI", "context$%^": "Explainable AI", "coworker&*()": "Senior Researcher"}'
expected_repaired_json = '{"task!@#": "Research XAI", "context$%^": "Explainable AI", "coworker&*()": "Senior Researcher"}'
result = parser._safe_repair_json(invalid_json)
assert result == expected_repaired_json
def test_parsing_with_whitespace(parser):
def test_parsing_with_whitespace(parser) -> None:
text = " Thought: Let's find the temperature \n Action: search \n Action Input: what is the temperature in SF? "
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -327,7 +326,7 @@ def test_parsing_with_whitespace(parser):
assert result.tool_input == "what is the temperature in SF?"
def test_parsing_with_special_characters(parser):
def test_parsing_with_special_characters(parser) -> None:
text = 'Thought: Let\'s find the temperature\nAction: search\nAction Input: "what is the temperature in SF?"'
result = parser.parse(text)
assert isinstance(result, AgentAction)
@@ -335,7 +334,7 @@ def test_parsing_with_special_characters(parser):
assert result.tool_input == "what is the temperature in SF?"
def test_integration_valid_and_invalid(parser):
def test_integration_valid_and_invalid(parser) -> None:
text = """
Thought: Let's find the temperature
Action: search
@@ -366,7 +365,7 @@ def test_integration_valid_and_invalid(parser):
class MockAgent:
def increment_formatting_errors(self):
def increment_formatting_errors(self) -> None:
pass

View File

@@ -1,17 +1,18 @@
import unittest
from unittest.mock import MagicMock, patch
import pytest
import requests
from crewai.cli.authentication.main import AuthenticationCommand
class TestAuthenticationCommand(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.auth_command = AuthenticationCommand()
@patch("crewai.cli.authentication.main.requests.post")
def test_get_device_code(self, mock_post):
def test_get_device_code(self, mock_post) -> None:
mock_response = MagicMock()
mock_response.json.return_value = {
"device_code": "123456",
@@ -23,16 +24,14 @@ class TestAuthenticationCommand(unittest.TestCase):
device_code_data = self.auth_command._get_device_code()
self.assertEqual(device_code_data["device_code"], "123456")
self.assertEqual(device_code_data["user_code"], "ABCDEF")
self.assertEqual(
device_code_data["verification_uri_complete"], "https://example.com"
)
self.assertEqual(device_code_data["interval"], 5)
assert device_code_data["device_code"] == "123456"
assert device_code_data["user_code"] == "ABCDEF"
assert device_code_data["verification_uri_complete"] == "https://example.com"
assert device_code_data["interval"] == 5
@patch("crewai.cli.authentication.main.console.print")
@patch("crewai.cli.authentication.main.webbrowser.open")
def test_display_auth_instructions(self, mock_open, mock_print):
def test_display_auth_instructions(self, mock_open, mock_print) -> None:
device_code_data = {
"verification_uri_complete": "https://example.com",
"user_code": "ABCDEF",
@@ -49,8 +48,8 @@ class TestAuthenticationCommand(unittest.TestCase):
@patch("crewai.cli.authentication.main.validate_token")
@patch("crewai.cli.authentication.main.console.print")
def test_poll_for_token_success(
self, mock_print, mock_validate_token, mock_post, mock_tool
):
self, mock_print, mock_validate_token, mock_post, mock_tool,
) -> None:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
@@ -66,12 +65,12 @@ class TestAuthenticationCommand(unittest.TestCase):
mock_validate_token.assert_called_once_with("TOKEN")
mock_print.assert_called_once_with(
"\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n"
"\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n",
)
@patch("crewai.cli.authentication.main.requests.post")
@patch("crewai.cli.authentication.main.console.print")
def test_poll_for_token_error(self, mock_print, mock_post):
def test_poll_for_token_error(self, mock_print, mock_post) -> None:
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.json.return_value = {
@@ -80,14 +79,14 @@ class TestAuthenticationCommand(unittest.TestCase):
}
mock_post.return_value = mock_response
with self.assertRaises(requests.HTTPError):
with pytest.raises(requests.HTTPError):
self.auth_command._poll_for_token({"device_code": "123456"})
mock_print.assert_not_called()
@patch("crewai.cli.authentication.main.requests.post")
@patch("crewai.cli.authentication.main.console.print")
def test_poll_for_token_timeout(self, mock_print, mock_post):
def test_poll_for_token_timeout(self, mock_print, mock_post) -> None:
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.json.return_value = {
@@ -99,5 +98,5 @@ class TestAuthenticationCommand(unittest.TestCase):
self.auth_command._poll_for_token({"device_code": "123456", "interval": 0.01})
mock_print.assert_called_once_with(
"Timeout: Failed to get the token. Please try again.", style="bold red"
"Timeout: Failed to get the token. Please try again.", style="bold red",
)

View File

@@ -11,14 +11,14 @@ from crewai.cli.authentication.utils import TokenManager, validate_token
class TestValidateToken(unittest.TestCase):
@patch("crewai.cli.authentication.utils.AsymmetricSignatureVerifier")
@patch("crewai.cli.authentication.utils.TokenVerifier")
def test_validate_token(self, mock_token_verifier, mock_asymmetric_verifier):
def test_validate_token(self, mock_token_verifier, mock_asymmetric_verifier) -> None:
mock_verifier_instance = mock_token_verifier.return_value
mock_id_token = "mock_id_token"
validate_token(mock_id_token)
mock_asymmetric_verifier.assert_called_once_with(
"https://crewai.us.auth0.com/.well-known/jwks.json"
"https://crewai.us.auth0.com/.well-known/jwks.json",
)
mock_token_verifier.assert_called_once_with(
signature_verifier=mock_asymmetric_verifier.return_value,
@@ -29,38 +29,38 @@ class TestValidateToken(unittest.TestCase):
class TestTokenManager(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.token_manager = TokenManager()
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
@patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key")
def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read):
def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read) -> None:
mock_key = Fernet.generate_key()
mock_get_or_create.return_value = mock_key
token_manager = TokenManager()
result = token_manager.key
self.assertEqual(result, mock_key)
assert result == mock_key
@patch("crewai.cli.authentication.utils.Fernet.generate_key")
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate):
def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate) -> None:
mock_key = b"new_key"
mock_read.return_value = None
mock_generate.return_value = mock_key
result = self.token_manager._get_or_create_key()
self.assertEqual(result, mock_key)
assert result == mock_key
mock_read.assert_called_once_with("secret.key")
mock_generate.assert_called_once()
mock_save.assert_called_once_with("secret.key", mock_key)
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
def test_save_tokens(self, mock_save):
def test_save_tokens(self, mock_save) -> None:
access_token = "test_token"
expires_in = 3600
@@ -68,10 +68,10 @@ class TestTokenManager(unittest.TestCase):
mock_save.assert_called_once()
args = mock_save.call_args[0]
self.assertEqual(args[0], "tokens.enc")
assert args[0] == "tokens.enc"
decrypted_data = self.token_manager.fernet.decrypt(args[1])
data = json.loads(decrypted_data)
self.assertEqual(data["access_token"], access_token)
assert data["access_token"] == access_token
expiration = datetime.fromisoformat(data["expiration"])
self.assertAlmostEqual(
expiration,
@@ -80,7 +80,7 @@ class TestTokenManager(unittest.TestCase):
)
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
def test_get_token_valid(self, mock_read):
def test_get_token_valid(self, mock_read) -> None:
access_token = "test_token"
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
@@ -89,10 +89,10 @@ class TestTokenManager(unittest.TestCase):
result = self.token_manager.get_token()
self.assertEqual(result, access_token)
assert result == access_token
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
def test_get_token_expired(self, mock_read):
def test_get_token_expired(self, mock_read) -> None:
access_token = "test_token"
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
@@ -101,12 +101,12 @@ class TestTokenManager(unittest.TestCase):
result = self.token_manager.get_token()
self.assertIsNone(result)
assert result is None
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
@patch("builtins.open", new_callable=unittest.mock.mock_open)
@patch("crewai.cli.authentication.utils.os.chmod")
def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path):
def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path) -> None:
mock_path = MagicMock()
mock_get_path.return_value = mock_path
filename = "test_file.txt"
@@ -121,9 +121,9 @@ class TestTokenManager(unittest.TestCase):
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
@patch(
"builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content"
"builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content",
)
def test_read_secure_file_exists(self, mock_open, mock_get_path):
def test_read_secure_file_exists(self, mock_open, mock_get_path) -> None:
mock_path = MagicMock()
mock_get_path.return_value = mock_path
mock_path.__truediv__.return_value.exists.return_value = True
@@ -131,12 +131,12 @@ class TestTokenManager(unittest.TestCase):
result = self.token_manager.read_secure_file(filename)
self.assertEqual(result, b"test_content")
assert result == b"test_content"
mock_path.__truediv__.assert_called_once_with(filename)
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb")
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
def test_read_secure_file_not_exists(self, mock_get_path):
def test_read_secure_file_not_exists(self, mock_get_path) -> None:
mock_path = MagicMock()
mock_get_path.return_value = mock_path
mock_path.__truediv__.return_value.exists.return_value = False
@@ -144,5 +144,5 @@ class TestTokenManager(unittest.TestCase):
result = self.token_manager.read_secure_file(filename)
self.assertIsNone(result)
assert result is None
mock_path.__truediv__.assert_called_once_with(filename)

View File

@@ -27,7 +27,7 @@ def runner():
@mock.patch("crewai.cli.cli.train_crew")
def test_train_default_iterations(train_crew, runner):
def test_train_default_iterations(train_crew, runner) -> None:
result = runner.invoke(train)
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
@@ -36,7 +36,7 @@ def test_train_default_iterations(train_crew, runner):
@mock.patch("crewai.cli.cli.train_crew")
def test_train_custom_iterations(train_crew, runner):
def test_train_custom_iterations(train_crew, runner) -> None:
result = runner.invoke(train, ["--n_iterations", "10"])
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
@@ -45,7 +45,7 @@ def test_train_custom_iterations(train_crew, runner):
@mock.patch("crewai.cli.cli.train_crew")
def test_train_invalid_string_iterations(train_crew, runner):
def test_train_invalid_string_iterations(train_crew, runner) -> None:
result = runner.invoke(train, ["--n_iterations", "invalid"])
train_crew.assert_not_called()
@@ -66,12 +66,12 @@ def mock_crew():
@pytest.fixture
def mock_get_crews(mock_crew):
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew],
) as mock_get_crew:
yield mock_get_crew
def test_reset_all_memories(mock_get_crews, runner):
def test_reset_all_memories(mock_get_crews, runner) -> None:
result = runner.invoke(reset_memories, ["-a"])
call_count = 0
@@ -86,7 +86,7 @@ def test_reset_all_memories(mock_get_crews, runner):
assert call_count == 1, "reset_memories should have been called once"
def test_reset_short_term_memories(mock_get_crews, runner):
def test_reset_short_term_memories(mock_get_crews, runner) -> None:
result = runner.invoke(reset_memories, ["-s"])
call_count = 0
for crew in mock_get_crews.return_value:
@@ -99,7 +99,7 @@ def test_reset_short_term_memories(mock_get_crews, runner):
assert call_count == 1, "reset_memories should have been called once"
def test_reset_entity_memories(mock_get_crews, runner):
def test_reset_entity_memories(mock_get_crews, runner) -> None:
result = runner.invoke(reset_memories, ["-e"])
call_count = 0
for crew in mock_get_crews.return_value:
@@ -110,7 +110,7 @@ def test_reset_entity_memories(mock_get_crews, runner):
assert call_count == 1, "reset_memories should have been called once"
def test_reset_long_term_memories(mock_get_crews, runner):
def test_reset_long_term_memories(mock_get_crews, runner) -> None:
result = runner.invoke(reset_memories, ["-l"])
call_count = 0
for crew in mock_get_crews.return_value:
@@ -121,7 +121,7 @@ def test_reset_long_term_memories(mock_get_crews, runner):
assert call_count == 1, "reset_memories should have been called once"
def test_reset_kickoff_outputs(mock_get_crews, runner):
def test_reset_kickoff_outputs(mock_get_crews, runner) -> None:
result = runner.invoke(reset_memories, ["-k"])
call_count = 0
for crew in mock_get_crews.return_value:
@@ -135,12 +135,12 @@ def test_reset_kickoff_outputs(mock_get_crews, runner):
assert call_count == 1, "reset_memories should have been called once"
def test_reset_multiple_memory_flags(mock_get_crews, runner):
def test_reset_multiple_memory_flags(mock_get_crews, runner) -> None:
result = runner.invoke(reset_memories, ["-s", "-l"])
call_count = 0
for crew in mock_get_crews.return_value:
crew.reset_memories.assert_has_calls(
[mock.call(command_type="long"), mock.call(command_type="short")]
[mock.call(command_type="long"), mock.call(command_type="short")],
)
assert (
f"[Crew ({crew.name})] Long term memory has been reset.\n"
@@ -151,7 +151,7 @@ def test_reset_multiple_memory_flags(mock_get_crews, runner):
assert call_count == 1, "reset_memories should have been called once"
def test_reset_knowledge(mock_get_crews, runner):
def test_reset_knowledge(mock_get_crews, runner) -> None:
result = runner.invoke(reset_memories, ["--knowledge"])
call_count = 0
for crew in mock_get_crews.return_value:
@@ -162,7 +162,7 @@ def test_reset_knowledge(mock_get_crews, runner):
assert call_count == 1, "reset_memories should have been called once"
def test_reset_memory_from_many_crews(mock_get_crews, runner):
def test_reset_memory_from_many_crews(mock_get_crews, runner) -> None:
crews = []
for crew_id in ["id-1234", "id-5678"]:
@@ -185,7 +185,7 @@ def test_reset_memory_from_many_crews(mock_get_crews, runner):
assert call_count == 2, "reset_memories should have been called twice"
def test_reset_no_memory_flags(runner):
def test_reset_no_memory_flags(runner) -> None:
result = runner.invoke(
reset_memories,
)
@@ -195,21 +195,21 @@ def test_reset_no_memory_flags(runner):
)
def test_version_flag(runner):
def test_version_flag(runner) -> None:
result = runner.invoke(version)
assert result.exit_code == 0
assert "crewai version:" in result.output
def test_version_command(runner):
def test_version_command(runner) -> None:
result = runner.invoke(version)
assert result.exit_code == 0
assert "crewai version:" in result.output
def test_version_command_with_tools(runner):
def test_version_command_with_tools(runner) -> None:
result = runner.invoke(version, ["--tools"])
assert result.exit_code == 0
@@ -221,7 +221,7 @@ def test_version_command_with_tools(runner):
@mock.patch("crewai.cli.cli.evaluate_crew")
def test_test_default_iterations(evaluate_crew, runner):
def test_test_default_iterations(evaluate_crew, runner) -> None:
result = runner.invoke(test)
evaluate_crew.assert_called_once_with(3, "gpt-4o-mini")
@@ -230,7 +230,7 @@ def test_test_default_iterations(evaluate_crew, runner):
@mock.patch("crewai.cli.cli.evaluate_crew")
def test_test_custom_iterations(evaluate_crew, runner):
def test_test_custom_iterations(evaluate_crew, runner) -> None:
result = runner.invoke(test, ["--n_iterations", "5", "--model", "gpt-4o"])
evaluate_crew.assert_called_once_with(5, "gpt-4o")
@@ -239,7 +239,7 @@ def test_test_custom_iterations(evaluate_crew, runner):
@mock.patch("crewai.cli.cli.evaluate_crew")
def test_test_invalid_string_iterations(evaluate_crew, runner):
def test_test_invalid_string_iterations(evaluate_crew, runner) -> None:
result = runner.invoke(test, ["--n_iterations", "invalid"])
evaluate_crew.assert_not_called()
@@ -251,7 +251,7 @@ def test_test_invalid_string_iterations(evaluate_crew, runner):
@mock.patch("crewai.cli.cli.AuthenticationCommand")
def test_signup(command, runner):
def test_signup(command, runner) -> None:
mock_auth = command.return_value
result = runner.invoke(signup)
@@ -260,7 +260,7 @@ def test_signup(command, runner):
@mock.patch("crewai.cli.cli.DeployCommand")
def test_deploy_create(command, runner):
def test_deploy_create(command, runner) -> None:
mock_deploy = command.return_value
result = runner.invoke(deploy_create)
@@ -269,7 +269,7 @@ def test_deploy_create(command, runner):
@mock.patch("crewai.cli.cli.DeployCommand")
def test_deploy_list(command, runner):
def test_deploy_list(command, runner) -> None:
mock_deploy = command.return_value
result = runner.invoke(deploy_list)
@@ -278,7 +278,7 @@ def test_deploy_list(command, runner):
@mock.patch("crewai.cli.cli.DeployCommand")
def test_deploy_push(command, runner):
def test_deploy_push(command, runner) -> None:
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_push, ["-u", uuid])
@@ -288,7 +288,7 @@ def test_deploy_push(command, runner):
@mock.patch("crewai.cli.cli.DeployCommand")
def test_deploy_push_no_uuid(command, runner):
def test_deploy_push_no_uuid(command, runner) -> None:
mock_deploy = command.return_value
result = runner.invoke(deploy_push)
@@ -297,7 +297,7 @@ def test_deploy_push_no_uuid(command, runner):
@mock.patch("crewai.cli.cli.DeployCommand")
def test_deploy_status(command, runner):
def test_deploy_status(command, runner) -> None:
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deply_status, ["-u", uuid])
@@ -307,7 +307,7 @@ def test_deploy_status(command, runner):
@mock.patch("crewai.cli.cli.DeployCommand")
def test_deploy_status_no_uuid(command, runner):
def test_deploy_status_no_uuid(command, runner) -> None:
mock_deploy = command.return_value
result = runner.invoke(deply_status)
@@ -316,7 +316,7 @@ def test_deploy_status_no_uuid(command, runner):
@mock.patch("crewai.cli.cli.DeployCommand")
def test_deploy_logs(command, runner):
def test_deploy_logs(command, runner) -> None:
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_logs, ["-u", uuid])
@@ -326,7 +326,7 @@ def test_deploy_logs(command, runner):
@mock.patch("crewai.cli.cli.DeployCommand")
def test_deploy_logs_no_uuid(command, runner):
def test_deploy_logs_no_uuid(command, runner) -> None:
mock_deploy = command.return_value
result = runner.invoke(deploy_logs)
@@ -335,7 +335,7 @@ def test_deploy_logs_no_uuid(command, runner):
@mock.patch("crewai.cli.cli.DeployCommand")
def test_deploy_remove(command, runner):
def test_deploy_remove(command, runner) -> None:
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_remove, ["-u", uuid])
@@ -345,7 +345,7 @@ def test_deploy_remove(command, runner):
@mock.patch("crewai.cli.cli.DeployCommand")
def test_deploy_remove_no_uuid(command, runner):
def test_deploy_remove_no_uuid(command, runner) -> None:
mock_deploy = command.return_value
result = runner.invoke(deploy_remove)
@@ -355,12 +355,11 @@ def test_deploy_remove_no_uuid(command, runner):
@mock.patch("crewai.cli.add_crew_to_flow.create_embedded_crew")
@mock.patch("pathlib.Path.exists", return_value=True) # Mock the existence check
def test_flow_add_crew(mock_path_exists, mock_create_embedded_crew, runner):
def test_flow_add_crew(mock_path_exists, mock_create_embedded_crew, runner) -> None:
crew_name = "new_crew"
result = runner.invoke(flow_add_crew, [crew_name])
# Log the output for debugging
print(result.output)
assert result.exit_code == 0, f"Command failed with output: {result.output}"
assert f"Adding crew {crew_name} to the flow" in result.output
@@ -373,11 +372,11 @@ def test_flow_add_crew(mock_path_exists, mock_create_embedded_crew, runner):
assert isinstance(call_kwargs["parent_folder"], Path)
def test_add_crew_to_flow_not_in_root(runner):
def test_add_crew_to_flow_not_in_root(runner) -> None:
# Simulate not being in the root of a flow project
with mock.patch("pathlib.Path.exists", autospec=True) as mock_exists:
# Mock Path.exists to return False when checking for pyproject.toml
def exists_side_effect(self):
def exists_side_effect(self) -> bool:
if self.name == "pyproject.toml":
return False # Simulate that pyproject.toml does not exist
return True # All other paths exist
@@ -388,5 +387,5 @@ def test_add_crew_to_flow_not_in_root(runner):
assert result.exit_code != 0
assert "This command must be run from the root of a flow project." in str(
result.output
result.output,
)

View File

@@ -8,34 +8,34 @@ from crewai.cli.config import Settings
class TestSettings(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.test_dir = Path(tempfile.mkdtemp())
self.config_path = self.test_dir / "settings.json"
def tearDown(self):
def tearDown(self) -> None:
shutil.rmtree(self.test_dir)
def test_empty_initialization(self):
def test_empty_initialization(self) -> None:
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)
self.assertIsNone(settings.tool_repository_password)
assert settings.tool_repository_username is None
assert settings.tool_repository_password is None
def test_initialization_with_data(self):
def test_initialization_with_data(self) -> None:
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
config_path=self.config_path, tool_repository_username="user1",
)
self.assertEqual(settings.tool_repository_username, "user1")
self.assertIsNone(settings.tool_repository_password)
assert settings.tool_repository_username == "user1"
assert settings.tool_repository_password is None
def test_initialization_with_existing_file(self):
def test_initialization_with_existing_file(self) -> None:
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump({"tool_repository_username": "file_user"}, f)
settings = Settings(config_path=self.config_path)
self.assertEqual(settings.tool_repository_username, "file_user")
assert settings.tool_repository_username == "file_user"
def test_merge_file_and_input_data(self):
def test_merge_file_and_input_data(self) -> None:
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump(
@@ -47,61 +47,61 @@ class TestSettings(unittest.TestCase):
)
settings = Settings(
config_path=self.config_path, tool_repository_username="new_user"
config_path=self.config_path, tool_repository_username="new_user",
)
self.assertEqual(settings.tool_repository_username, "new_user")
self.assertEqual(settings.tool_repository_password, "file_pass")
assert settings.tool_repository_username == "new_user"
assert settings.tool_repository_password == "file_pass"
def test_dump_new_settings(self):
def test_dump_new_settings(self) -> None:
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
config_path=self.config_path, tool_repository_username="user1",
)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertEqual(saved_data["tool_repository_username"], "user1")
assert saved_data["tool_repository_username"] == "user1"
def test_update_existing_settings(self):
def test_update_existing_settings(self) -> None:
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump({"existing_setting": "value"}, f)
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
config_path=self.config_path, tool_repository_username="user1",
)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertEqual(saved_data["existing_setting"], "value")
self.assertEqual(saved_data["tool_repository_username"], "user1")
assert saved_data["existing_setting"] == "value"
assert saved_data["tool_repository_username"] == "user1"
def test_none_values(self):
def test_none_values(self) -> None:
settings = Settings(config_path=self.config_path, tool_repository_username=None)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertIsNone(saved_data.get("tool_repository_username"))
assert saved_data.get("tool_repository_username") is None
def test_invalid_json_in_config(self):
def test_invalid_json_in_config(self) -> None:
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
f.write("invalid json")
try:
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)
assert settings.tool_repository_username is None
except json.JSONDecodeError:
self.fail("Settings initialization should handle invalid JSON")
def test_empty_config_file(self):
def test_empty_config_file(self) -> None:
self.config_path.parent.mkdir(parents=True, exist_ok=True)
self.config_path.touch()
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)
assert settings.tool_repository_username is None

View File

@@ -15,7 +15,7 @@ class TestDeployCommand(unittest.TestCase):
@patch("crewai.cli.command.get_auth_token")
@patch("crewai.cli.deploy.main.get_project_name")
@patch("crewai.cli.command.PlusAPI")
def setUp(self, mock_plus_api, mock_get_project_name, mock_get_auth_token):
def setUp(self, mock_plus_api, mock_get_project_name, mock_get_auth_token) -> None:
self.mock_get_auth_token = mock_get_auth_token
self.mock_get_project_name = mock_get_project_name
self.mock_plus_api = mock_plus_api
@@ -26,18 +26,18 @@ class TestDeployCommand(unittest.TestCase):
self.deploy_command = DeployCommand()
self.mock_client = self.deploy_command.plus_api_client
def test_init_success(self):
self.assertEqual(self.deploy_command.project_name, "test_project")
def test_init_success(self) -> None:
assert self.deploy_command.project_name == "test_project"
self.mock_plus_api.assert_called_once_with(api_key="test_token")
@patch("crewai.cli.command.get_auth_token")
def test_init_failure(self, mock_get_auth_token):
def test_init_failure(self, mock_get_auth_token) -> None:
mock_get_auth_token.side_effect = Exception("Auth failed")
with self.assertRaises(SystemExit):
with pytest.raises(SystemExit):
DeployCommand()
def test_validate_response_successful_response(self):
def test_validate_response_successful_response(self) -> None:
mock_response = Mock(spec=requests.Response)
mock_response.json.return_value = {"message": "Success"}
mock_response.status_code = 200
@@ -47,7 +47,7 @@ class TestDeployCommand(unittest.TestCase):
self.deploy_command._validate_response(mock_response)
assert fake_out.getvalue() == ""
def test_validate_response_json_decode_error(self):
def test_validate_response_json_decode_error(self) -> None:
mock_response = Mock(spec=requests.Response)
mock_response.json.side_effect = JSONDecodeError("Decode error", "", 0)
mock_response.status_code = 500
@@ -64,7 +64,7 @@ class TestDeployCommand(unittest.TestCase):
assert "Status Code: 500" in output
assert "Response:\nb'Invalid JSON'" in output
def test_validate_response_422_error(self):
def test_validate_response_422_error(self) -> None:
mock_response = Mock(spec=requests.Response)
mock_response.json.return_value = {
"field1": ["Error message 1"],
@@ -84,7 +84,7 @@ class TestDeployCommand(unittest.TestCase):
assert "Field1 Error message 1" in output
assert "Field2 Error message 2" in output
def test_validate_response_other_error(self):
def test_validate_response_other_error(self) -> None:
mock_response = Mock(spec=requests.Response)
mock_response.json.return_value = {"error": "Something went wrong"}
mock_response.status_code = 500
@@ -97,29 +97,29 @@ class TestDeployCommand(unittest.TestCase):
assert "Request to Enterprise API failed. Details:" in output
assert "Details:\nSomething went wrong" in output
def test_standard_no_param_error_message(self):
def test_standard_no_param_error_message(self) -> None:
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command._standard_no_param_error_message()
self.assertIn("No UUID provided", fake_out.getvalue())
assert "No UUID provided" in fake_out.getvalue()
def test_display_deployment_info(self):
def test_display_deployment_info(self) -> None:
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command._display_deployment_info(
{"uuid": "test-uuid", "status": "deployed"}
{"uuid": "test-uuid", "status": "deployed"},
)
self.assertIn("Deploying the crew...", fake_out.getvalue())
self.assertIn("test-uuid", fake_out.getvalue())
self.assertIn("deployed", fake_out.getvalue())
assert "Deploying the crew..." in fake_out.getvalue()
assert "test-uuid" in fake_out.getvalue()
assert "deployed" in fake_out.getvalue()
def test_display_logs(self):
def test_display_logs(self) -> None:
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command._display_logs(
[{"timestamp": "2023-01-01", "level": "INFO", "message": "Test log"}]
[{"timestamp": "2023-01-01", "level": "INFO", "message": "Test log"}],
)
self.assertIn("2023-01-01 - INFO: Test log", fake_out.getvalue())
assert "2023-01-01 - INFO: Test log" in fake_out.getvalue()
@patch("crewai.cli.deploy.main.DeployCommand._display_deployment_info")
def test_deploy_with_uuid(self, mock_display):
def test_deploy_with_uuid(self, mock_display) -> None:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"uuid": "test-uuid"}
@@ -131,7 +131,7 @@ class TestDeployCommand(unittest.TestCase):
mock_display.assert_called_once_with({"uuid": "test-uuid"})
@patch("crewai.cli.deploy.main.DeployCommand._display_deployment_info")
def test_deploy_with_project_name(self, mock_display):
def test_deploy_with_project_name(self, mock_display) -> None:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"uuid": "test-uuid"}
@@ -145,7 +145,7 @@ class TestDeployCommand(unittest.TestCase):
@patch("crewai.cli.deploy.main.fetch_and_json_env_file")
@patch("crewai.cli.deploy.main.git.Repository.origin_url")
@patch("builtins.input")
def test_create_crew(self, mock_input, mock_git_origin_url, mock_fetch_env):
def test_create_crew(self, mock_input, mock_git_origin_url, mock_fetch_env) -> None:
mock_fetch_env.return_value = {"ENV_VAR": "value"}
mock_git_origin_url.return_value = "https://github.com/test/repo.git"
mock_input.return_value = ""
@@ -157,10 +157,10 @@ class TestDeployCommand(unittest.TestCase):
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command.create_crew()
self.assertIn("Deployment created successfully!", fake_out.getvalue())
self.assertIn("new-uuid", fake_out.getvalue())
assert "Deployment created successfully!" in fake_out.getvalue()
assert "new-uuid" in fake_out.getvalue()
def test_list_crews(self):
def test_list_crews(self) -> None:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = [
@@ -171,10 +171,10 @@ class TestDeployCommand(unittest.TestCase):
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command.list_crews()
self.assertIn("Crew1 (uuid1) active", fake_out.getvalue())
self.assertIn("Crew2 (uuid2) inactive", fake_out.getvalue())
assert "Crew1 (uuid1) active" in fake_out.getvalue()
assert "Crew2 (uuid2) inactive" in fake_out.getvalue()
def test_get_crew_status(self):
def test_get_crew_status(self) -> None:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"name": "InternalCrew", "status": "active"}
@@ -182,10 +182,10 @@ class TestDeployCommand(unittest.TestCase):
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command.get_crew_status()
self.assertIn("InternalCrew", fake_out.getvalue())
self.assertIn("active", fake_out.getvalue())
assert "InternalCrew" in fake_out.getvalue()
assert "active" in fake_out.getvalue()
def test_get_crew_logs(self):
def test_get_crew_logs(self) -> None:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = [
@@ -196,22 +196,20 @@ class TestDeployCommand(unittest.TestCase):
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command.get_crew_logs(None)
self.assertIn("2023-01-01 - INFO: Log1", fake_out.getvalue())
self.assertIn("2023-01-02 - ERROR: Log2", fake_out.getvalue())
assert "2023-01-01 - INFO: Log1" in fake_out.getvalue()
assert "2023-01-02 - ERROR: Log2" in fake_out.getvalue()
def test_remove_crew(self):
def test_remove_crew(self) -> None:
mock_response = MagicMock()
mock_response.status_code = 204
self.mock_client.delete_crew_by_name.return_value = mock_response
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command.remove_crew(None)
self.assertIn(
"Crew 'test_project' removed successfully", fake_out.getvalue()
)
assert "Crew 'test_project' removed successfully" in fake_out.getvalue()
@unittest.skipIf(sys.version_info < (3, 11), "Requires Python 3.11+")
def test_parse_toml_python_311_plus(self):
def test_parse_toml_python_311_plus(self) -> None:
toml_content = """
[tool.poetry]
name = "test_project"
@@ -222,7 +220,7 @@ class TestDeployCommand(unittest.TestCase):
crewai = { extras = ["tools"], version = ">=0.51.0,<1.0.0" }
"""
parsed = parse_toml(toml_content)
self.assertEqual(parsed["tool"]["poetry"]["name"], "test_project")
assert parsed["tool"]["poetry"]["name"] == "test_project"
@patch(
"builtins.open",
@@ -235,12 +233,11 @@ class TestDeployCommand(unittest.TestCase):
dependencies = ["crewai"]
""",
)
def test_get_project_name_python_310(self, mock_open):
def test_get_project_name_python_310(self, mock_open) -> None:
from crewai.cli.utils import get_project_name
project_name = get_project_name()
print("project_name", project_name)
self.assertEqual(project_name, "test_project")
assert project_name == "test_project"
@unittest.skipIf(sys.version_info < (3, 11), "Requires Python 3.11+")
@patch(
@@ -254,13 +251,13 @@ class TestDeployCommand(unittest.TestCase):
dependencies = ["crewai"]
""",
)
def test_get_project_name_python_311_plus(self, mock_open):
def test_get_project_name_python_311_plus(self, mock_open) -> None:
from crewai.cli.utils import get_project_name
project_name = get_project_name()
self.assertEqual(project_name, "test_project")
assert project_name == "test_project"
def test_get_crewai_version(self):
def test_get_crewai_version(self) -> None:
from crewai.cli.version import get_crewai_version
assert isinstance(get_crewai_version(), str)

View File

@@ -1,14 +1,13 @@
import pytest
from crewai.cli.constants import ENV_VARS, MODELS, PROVIDERS
def test_huggingface_in_providers():
def test_huggingface_in_providers() -> None:
"""Test that Huggingface is in the PROVIDERS list."""
assert "huggingface" in PROVIDERS
def test_huggingface_env_vars():
def test_huggingface_env_vars() -> None:
"""Test that Huggingface environment variables are properly configured."""
assert "huggingface" in ENV_VARS
assert any(
@@ -17,7 +16,7 @@ def test_huggingface_env_vars():
)
def test_huggingface_models():
def test_huggingface_models() -> None:
"""Test that Huggingface models are properly configured."""
assert "huggingface" in MODELS
assert len(MODELS["huggingface"]) > 0

View File

@@ -7,7 +7,7 @@ from crewai.cli import evaluate_crew
@pytest.mark.parametrize(
"n_iterations,model",
("n_iterations", "model"),
[
(1, "gpt-4o"),
(5, "gpt-3.5-turbo"),
@@ -15,10 +15,10 @@ from crewai.cli import evaluate_crew
],
)
@mock.patch("crewai.cli.evaluate_crew.subprocess.run")
def test_crew_success(mock_subprocess_run, n_iterations, model):
def test_crew_success(mock_subprocess_run, n_iterations, model) -> None:
"""Test the crew function for successful execution."""
mock_subprocess_run.return_value = subprocess.CompletedProcess(
args=f"uv run test {n_iterations} {model}", returncode=0
args=f"uv run test {n_iterations} {model}", returncode=0,
)
result = evaluate_crew.evaluate_crew(n_iterations, model)
@@ -32,7 +32,7 @@ def test_crew_success(mock_subprocess_run, n_iterations, model):
@mock.patch("crewai.cli.evaluate_crew.click")
def test_test_crew_zero_iterations(click):
def test_test_crew_zero_iterations(click) -> None:
evaluate_crew.evaluate_crew(0, "gpt-4o")
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
@@ -41,7 +41,7 @@ def test_test_crew_zero_iterations(click):
@mock.patch("crewai.cli.evaluate_crew.click")
def test_test_crew_negative_iterations(click):
def test_test_crew_negative_iterations(click) -> None:
evaluate_crew.evaluate_crew(-2, "gpt-4o")
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
@@ -51,7 +51,7 @@ def test_test_crew_negative_iterations(click):
@mock.patch("crewai.cli.evaluate_crew.click")
@mock.patch("crewai.cli.evaluate_crew.subprocess.run")
def test_test_crew_called_process_error(mock_subprocess_run, click):
def test_test_crew_called_process_error(mock_subprocess_run, click) -> None:
n_iterations = 5
mock_subprocess_run.side_effect = subprocess.CalledProcessError(
returncode=1,
@@ -74,13 +74,13 @@ def test_test_crew_called_process_error(mock_subprocess_run, click):
err=True,
),
mock.call.echo("Error", err=True),
]
],
)
@mock.patch("crewai.cli.evaluate_crew.click")
@mock.patch("crewai.cli.evaluate_crew.subprocess.run")
def test_test_crew_unexpected_exception(mock_subprocess_run, click):
def test_test_crew_unexpected_exception(mock_subprocess_run, click) -> None:
# Arrange
n_iterations = 5
mock_subprocess_run.side_effect = Exception("Unexpected error")
@@ -93,5 +93,5 @@ def test_test_crew_unexpected_exception(mock_subprocess_run, click):
check=True,
)
click.echo.assert_called_once_with(
"An unexpected error occurred: Unexpected error", err=True
"An unexpected error occurred: Unexpected error", err=True,
)

View File

@@ -3,7 +3,7 @@ import pytest
from crewai.cli.git import Repository
@pytest.fixture()
@pytest.fixture
def repository(fp):
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
@@ -11,7 +11,7 @@ def repository(fp):
return Repository(path=".")
def test_init_with_invalid_git_repo(fp):
def test_init_with_invalid_git_repo(fp) -> None:
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
fp.register(
["git", "rev-parse", "--is-inside-work-tree"],
@@ -23,16 +23,16 @@ def test_init_with_invalid_git_repo(fp):
Repository(path="invalid/path")
def test_is_git_not_installed(fp):
def test_is_git_not_installed(fp) -> None:
fp.register(["git", "--version"], returncode=1)
with pytest.raises(
ValueError, match="Git is not installed or not found in your PATH."
ValueError, match="Git is not installed or not found in your PATH.",
):
Repository(path=".")
def test_status(fp, repository):
def test_status(fp, repository) -> None:
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
@@ -40,7 +40,7 @@ def test_status(fp, repository):
assert repository.status() == "## main...origin/main [ahead 1]"
def test_has_uncommitted_changes(fp, repository):
def test_has_uncommitted_changes(fp, repository) -> None:
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main\n M somefile.txt\n",
@@ -48,7 +48,7 @@ def test_has_uncommitted_changes(fp, repository):
assert repository.has_uncommitted_changes() is True
def test_is_ahead_or_behind(fp, repository):
def test_is_ahead_or_behind(fp, repository) -> None:
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
@@ -56,17 +56,17 @@ def test_is_ahead_or_behind(fp, repository):
assert repository.is_ahead_or_behind() is True
def test_is_synced_when_synced(fp, repository):
def test_is_synced_when_synced(fp, repository) -> None:
fp.register(
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n",
)
fp.register(
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n",
)
assert repository.is_synced() is True
def test_is_synced_with_uncommitted_changes(fp, repository):
def test_is_synced_with_uncommitted_changes(fp, repository) -> None:
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main\n M somefile.txt\n",
@@ -74,7 +74,7 @@ def test_is_synced_with_uncommitted_changes(fp, repository):
assert repository.is_synced() is False
def test_is_synced_when_ahead_or_behind(fp, repository):
def test_is_synced_when_ahead_or_behind(fp, repository) -> None:
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
@@ -86,7 +86,7 @@ def test_is_synced_when_ahead_or_behind(fp, repository):
assert repository.is_synced() is False
def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository):
def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository) -> None:
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n M somefile.txt\n",
@@ -94,7 +94,7 @@ def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository):
assert repository.is_synced() is False
def test_origin_url(fp, repository):
def test_origin_url(fp, repository) -> None:
fp.register(
["git", "remote", "get-url", "origin"],
stdout="https://github.com/user/repo.git\n",

View File

@@ -6,43 +6,43 @@ from crewai.cli.plus_api import PlusAPI
class TestPlusAPI(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.api_key = "test_api_key"
self.api = PlusAPI(self.api_key)
def test_init(self):
self.assertEqual(self.api.api_key, self.api_key)
self.assertEqual(self.api.headers["Authorization"], f"Bearer {self.api_key}")
self.assertEqual(self.api.headers["Content-Type"], "application/json")
self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"])
self.assertTrue(self.api.headers["X-Crewai-Version"])
def test_init(self) -> None:
assert self.api.api_key == self.api_key
assert self.api.headers["Authorization"] == f"Bearer {self.api_key}"
assert self.api.headers["Content-Type"] == "application/json"
assert "CrewAI-CLI/" in self.api.headers["User-Agent"]
assert self.api.headers["X-Crewai-Version"]
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_login_to_tool_repository(self, mock_make_request):
def test_login_to_tool_repository(self, mock_make_request) -> None:
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.login_to_tool_repository()
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools/login"
"POST", "/crewai_plus/api/v1/tools/login",
)
self.assertEqual(response, mock_response)
assert response == mock_response
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):
def test_get_tool(self, mock_make_request) -> None:
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.get_tool("test_tool_handle")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
"GET", "/crewai_plus/api/v1/tools/test_tool_handle",
)
self.assertEqual(response, mock_response)
assert response == mock_response
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_publish_tool(self, mock_make_request):
def test_publish_tool(self, mock_make_request) -> None:
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
@@ -52,7 +52,7 @@ class TestPlusAPI(unittest.TestCase):
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
handle, public, version, description, encoded_file,
)
params = {
@@ -63,12 +63,12 @@ class TestPlusAPI(unittest.TestCase):
"description": description,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
"POST", "/crewai_plus/api/v1/tools", json=params,
)
self.assertEqual(response, mock_response)
assert response == mock_response
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_publish_tool_without_description(self, mock_make_request):
def test_publish_tool_without_description(self, mock_make_request) -> None:
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
@@ -78,7 +78,7 @@ class TestPlusAPI(unittest.TestCase):
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
handle, public, version, description, encoded_file,
)
params = {
@@ -89,12 +89,12 @@ class TestPlusAPI(unittest.TestCase):
"description": description,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
"POST", "/crewai_plus/api/v1/tools", json=params,
)
self.assertEqual(response, mock_response)
assert response == mock_response
@patch("crewai.cli.plus_api.requests.Session")
def test_make_request(self, mock_session):
def test_make_request(self, mock_session) -> None:
mock_response = MagicMock()
mock_session_instance = mock_session.return_value
@@ -104,94 +104,91 @@ class TestPlusAPI(unittest.TestCase):
mock_session.assert_called_once()
mock_session_instance.request.assert_called_once_with(
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers,
)
mock_session_instance.trust_env = False
self.assertEqual(response, mock_response)
assert response == mock_response
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_deploy_by_name(self, mock_make_request):
def test_deploy_by_name(self, mock_make_request) -> None:
self.api.deploy_by_name("test_project")
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy"
"POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy",
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_deploy_by_uuid(self, mock_make_request):
def test_deploy_by_uuid(self, mock_make_request) -> None:
self.api.deploy_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews/test_uuid/deploy"
"POST", "/crewai_plus/api/v1/crews/test_uuid/deploy",
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_crew_status_by_name(self, mock_make_request):
def test_crew_status_by_name(self, mock_make_request) -> None:
self.api.crew_status_by_name("test_project")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status",
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_crew_status_by_uuid(self, mock_make_request):
def test_crew_status_by_uuid(self, mock_make_request) -> None:
self.api.crew_status_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/status"
"GET", "/crewai_plus/api/v1/crews/test_uuid/status",
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_crew_by_name(self, mock_make_request):
def test_crew_by_name(self, mock_make_request) -> None:
self.api.crew_by_name("test_project")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment"
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment",
)
self.api.crew_by_name("test_project", "custom_log")
mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log",
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_crew_by_uuid(self, mock_make_request):
def test_crew_by_uuid(self, mock_make_request) -> None:
self.api.crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment"
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment",
)
self.api.crew_by_uuid("test_uuid", "custom_log")
mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log",
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_delete_crew_by_name(self, mock_make_request):
def test_delete_crew_by_name(self, mock_make_request) -> None:
self.api.delete_crew_by_name("test_project")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project",
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_delete_crew_by_uuid(self, mock_make_request):
def test_delete_crew_by_uuid(self, mock_make_request) -> None:
self.api.delete_crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/test_uuid"
"DELETE", "/crewai_plus/api/v1/crews/test_uuid",
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_list_crews(self, mock_make_request):
def test_list_crews(self, mock_make_request) -> None:
self.api.list_crews()
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_create_crew(self, mock_make_request):
def test_create_crew(self, mock_make_request) -> None:
payload = {"name": "test_crew"}
self.api.create_crew(payload)
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews", json=payload
"POST", "/crewai_plus/api/v1/crews", json=payload,
)
@patch.dict(os.environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"})
def test_custom_base_url(self):
def test_custom_base_url(self) -> None:
custom_api = PlusAPI("test_key")
self.assertEqual(
custom_api.base_url,
"https://custom-url.com/api",
)
assert custom_api.base_url == "https://custom-url.com/api"

View File

@@ -23,18 +23,18 @@ def temp_tree():
shutil.rmtree(root_dir)
def create_file(path, content):
def create_file(path, content) -> None:
with open(path, "w") as f:
f.write(content)
def test_tree_find_and_replace_file_content(temp_tree):
def test_tree_find_and_replace_file_content(temp_tree) -> None:
utils.tree_find_and_replace(temp_tree, "world", "universe")
with open(os.path.join(temp_tree, "file1.txt"), "r") as f:
with open(os.path.join(temp_tree, "file1.txt")) as f:
assert f.read() == "Hello, universe!"
def test_tree_find_and_replace_file_name(temp_tree):
def test_tree_find_and_replace_file_name(temp_tree) -> None:
old_path = os.path.join(temp_tree, "file2.txt")
new_path = os.path.join(temp_tree, "file2_renamed.txt")
os.rename(old_path, new_path)
@@ -43,19 +43,19 @@ def test_tree_find_and_replace_file_name(temp_tree):
assert not os.path.exists(new_path)
def test_tree_find_and_replace_directory_name(temp_tree):
def test_tree_find_and_replace_directory_name(temp_tree) -> None:
utils.tree_find_and_replace(temp_tree, "empty", "renamed")
assert os.path.exists(os.path.join(temp_tree, "renamed_dir"))
assert not os.path.exists(os.path.join(temp_tree, "empty_dir"))
def test_tree_find_and_replace_nested_content(temp_tree):
def test_tree_find_and_replace_nested_content(temp_tree) -> None:
utils.tree_find_and_replace(temp_tree, "Nested", "Updated")
with open(os.path.join(temp_tree, "nested_dir", "nested_file.txt"), "r") as f:
with open(os.path.join(temp_tree, "nested_dir", "nested_file.txt")) as f:
assert f.read() == "Updated content"
def test_tree_find_and_replace_no_matches(temp_tree):
def test_tree_find_and_replace_no_matches(temp_tree) -> None:
utils.tree_find_and_replace(temp_tree, "nonexistent", "replacement")
assert set(os.listdir(temp_tree)) == {
"file1.txt",
@@ -65,7 +65,7 @@ def test_tree_find_and_replace_no_matches(temp_tree):
}
def test_tree_copy_full_structure(temp_tree):
def test_tree_copy_full_structure(temp_tree) -> None:
dest_dir = tempfile.mkdtemp()
try:
utils.tree_copy(temp_tree, dest_dir)
@@ -79,19 +79,19 @@ def test_tree_copy_full_structure(temp_tree):
shutil.rmtree(dest_dir)
def test_tree_copy_preserve_content(temp_tree):
def test_tree_copy_preserve_content(temp_tree) -> None:
dest_dir = tempfile.mkdtemp()
try:
utils.tree_copy(temp_tree, dest_dir)
with open(os.path.join(dest_dir, "file1.txt"), "r") as f:
with open(os.path.join(dest_dir, "file1.txt")) as f:
assert f.read() == "Hello, world!"
with open(os.path.join(dest_dir, "nested_dir", "nested_file.txt"), "r") as f:
with open(os.path.join(dest_dir, "nested_dir", "nested_file.txt")) as f:
assert f.read() == "Nested content"
finally:
shutil.rmtree(dest_dir)
def test_tree_copy_to_existing_directory(temp_tree):
def test_tree_copy_to_existing_directory(temp_tree) -> None:
dest_dir = tempfile.mkdtemp()
try:
create_file(os.path.join(dest_dir, "existing_file.txt"), "I was here first")

View File

@@ -33,7 +33,7 @@ def tool_command():
@patch("crewai.cli.tools.main.subprocess.run")
def test_create_success(mock_subprocess, capsys, tool_command):
def test_create_success(mock_subprocess, capsys, tool_command) -> None:
with in_temp_dir():
tool_command.create("test-tool")
output = capsys.readouterr().out
@@ -43,11 +43,11 @@ def test_create_success(mock_subprocess, capsys, tool_command):
assert os.path.isfile(os.path.join("test_tool", "README.md"))
assert os.path.isfile(os.path.join("test_tool", "pyproject.toml"))
assert os.path.isfile(
os.path.join("test_tool", "src", "test_tool", "__init__.py")
os.path.join("test_tool", "src", "test_tool", "__init__.py"),
)
assert os.path.isfile(os.path.join("test_tool", "src", "test_tool", "tool.py"))
with open(os.path.join("test_tool", "src", "test_tool", "tool.py"), "r") as f:
with open(os.path.join("test_tool", "src", "test_tool", "tool.py")) as f:
content = f.read()
assert "class TestTool" in content
@@ -56,7 +56,7 @@ def test_create_success(mock_subprocess, capsys, tool_command):
@patch("crewai.cli.tools.main.subprocess.run")
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command):
def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command) -> None:
mock_get_response = MagicMock()
mock_get_response.status_code = 200
mock_get_response.json.return_value = {
@@ -87,7 +87,7 @@ def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command):
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_tool_not_found(mock_get, capsys, tool_command):
def test_install_tool_not_found(mock_get, capsys, tool_command) -> None:
mock_get_response = MagicMock()
mock_get_response.status_code = 404
mock_get.return_value = mock_get_response
@@ -101,7 +101,7 @@ def test_install_tool_not_found(mock_get, capsys, tool_command):
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_api_error(mock_get, capsys, tool_command):
def test_install_api_error(mock_get, capsys, tool_command) -> None:
mock_get_response = MagicMock()
mock_get_response.status_code = 500
mock_get.return_value = mock_get_response
@@ -115,7 +115,7 @@ def test_install_api_error(mock_get, capsys, tool_command):
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command):
def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command) -> None:
with raises(SystemExit):
tool_command.publish(is_public=True)
@@ -145,7 +145,7 @@ def test_publish_when_not_in_sync_and_force(
mock_get_project_version,
mock_get_project_name,
tool_command,
):
) -> None:
mock_publish_response = MagicMock()
mock_publish_response.status_code = 200
mock_publish_response.json.return_value = {"handle": "sample-tool"}
@@ -193,7 +193,7 @@ def test_publish_success(
mock_get_project_version,
mock_get_project_name,
tool_command,
):
) -> None:
mock_publish_response = MagicMock()
mock_publish_response.status_code = 200
mock_publish_response.json.return_value = {"handle": "sample-tool"}
@@ -240,7 +240,7 @@ def test_publish_failure(
mock_get_project_name,
capsys,
tool_command,
):
) -> None:
mock_publish_response = MagicMock()
mock_publish_response.status_code = 422
mock_publish_response.json.return_value = {"name": ["is already taken"]}
@@ -276,7 +276,7 @@ def test_publish_api_error(
mock_get_project_name,
capsys,
tool_command,
):
) -> None:
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.json.return_value = {"error": "Internal Server Error"}

View File

@@ -5,7 +5,7 @@ from crewai.cli.train_crew import train_crew
@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_positive_iterations(mock_subprocess_run):
def test_train_crew_positive_iterations(mock_subprocess_run) -> None:
n_iterations = 5
mock_subprocess_run.return_value = subprocess.CompletedProcess(
args=["uv", "run", "train", str(n_iterations)],
@@ -25,7 +25,7 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
@mock.patch("crewai.cli.train_crew.click")
def test_train_crew_zero_iterations(click):
def test_train_crew_zero_iterations(click) -> None:
train_crew(0, "trained_agents_data.pkl")
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
@@ -34,7 +34,7 @@ def test_train_crew_zero_iterations(click):
@mock.patch("crewai.cli.train_crew.click")
def test_train_crew_negative_iterations(click):
def test_train_crew_negative_iterations(click) -> None:
train_crew(-2, "trained_agents_data.pkl")
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
@@ -44,7 +44,7 @@ def test_train_crew_negative_iterations(click):
@mock.patch("crewai.cli.train_crew.click")
@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_called_process_error(mock_subprocess_run, click):
def test_train_crew_called_process_error(mock_subprocess_run, click) -> None:
n_iterations = 5
mock_subprocess_run.side_effect = subprocess.CalledProcessError(
returncode=1,
@@ -67,13 +67,13 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
err=True,
),
mock.call.echo("Error", err=True),
]
],
)
@mock.patch("crewai.cli.train_crew.click")
@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_unexpected_exception(mock_subprocess_run, click):
def test_train_crew_unexpected_exception(mock_subprocess_run, click) -> None:
n_iterations = 5
mock_subprocess_run.side_effect = Exception("Unexpected error")
train_crew(n_iterations, "trained_agents_data.pkl")
@@ -85,5 +85,5 @@ def test_train_crew_unexpected_exception(mock_subprocess_run, click):
check=True,
)
click.echo.assert_called_once_with(
"An unexpected error occurred: Unexpected error", err=True
"An unexpected error occurred: Unexpected error", err=True,
)
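The one-character hunks in this file (`err=True` becoming `err=True,` and `]` becoming `],`) all apply the same magic-trailing-comma style: a comma is added after the last argument or element of a multi-line call or literal so the formatter keeps one item per line. A small sketch under the assumption that the rule involved is ruff's flake8-commas check (COM812); `report_error` and `click_module` are illustrative names, not crewai APIs:

```python
def report_error_before(click_module, message: str) -> None:
    click_module.echo(
        f"An unexpected error occurred: {message}", err=True
    )


def report_error_after(click_module, message: str) -> None:
    # The trailing comma after err=True keeps the call expanded with one
    # argument per line, so later argument changes show up as one-line diffs.
    click_module.echo(
        f"An unexpected error occurred: {message}",
        err=True,
    )
```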

View File

@@ -19,8 +19,9 @@ def setup_test_environment():
# Validate that the directory was created successfully
if not storage_dir.exists() or not storage_dir.is_dir():
msg = f"Failed to create test storage directory: {storage_dir}"
raise RuntimeError(
f"Failed to create test storage directory: {storage_dir}"
msg,
)
# Verify directory permissions
@@ -29,9 +30,10 @@ def setup_test_environment():
test_file = storage_dir / ".permissions_test"
test_file.touch()
test_file.unlink()
except (OSError, IOError) as e:
except OSError as e:
msg = f"Test storage directory {storage_dir} is not writable: {e}"
raise RuntimeError(
f"Test storage directory {storage_dir} is not writable: {e}"
msg,
)
# Set environment variable to point to the test storage directory
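The two hunks above share one pattern: the exception text is assigned to a `msg` variable before `raise`, and `except (OSError, IOError)` collapses to `except OSError`, since `IOError` has been an alias of `OSError` since Python 3.3. A rough sketch of the resulting shape; `ensure_writable` is a hypothetical stand-in for the fixture body, not a crewai helper:

```python
from pathlib import Path


def ensure_writable(storage_dir: Path) -> None:
    # Build the message once, then raise it, mirroring the fixture above.
    if not storage_dir.exists() or not storage_dir.is_dir():
        msg = f"Failed to create test storage directory: {storage_dir}"
        raise RuntimeError(msg)
    probe = storage_dir / ".permissions_test"
    try:
        probe.touch()
        probe.unlink()
    except OSError as e:  # IOError is an alias of OSError on Python 3
        msg = f"Test storage directory {storage_dir} is not writable: {e}"
        raise RuntimeError(msg)
```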

File diff suppressed because it is too large

View File

@@ -1,5 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
from typing import Any
import pytest
@@ -15,11 +14,12 @@ class CustomLLM(BaseLLM):
that returns a predefined response for testing purposes.
"""
def __init__(self, response="Default response", model="test-model"):
def __init__(self, response="Default response", model="test-model") -> None:
"""Initialize the CustomLLM with a predefined response.
Args:
response: The predefined response to return from call().
"""
super().__init__(model=model)
self.response = response
@@ -32,8 +32,7 @@ class CustomLLM(BaseLLM):
callbacks=None,
available_functions=None,
):
"""
Mock LLM call that returns a predefined response.
"""Mock LLM call that returns a predefined response.
Properly formats messages to match OpenAI's expected structure.
"""
self.call_count += 1
@@ -57,6 +56,7 @@ class CustomLLM(BaseLLM):
Returns:
False, indicating that this LLM does not support function calling.
"""
return False
@@ -65,6 +65,7 @@ class CustomLLM(BaseLLM):
Returns:
False, indicating that this LLM does not support stop words.
"""
return False
@@ -73,12 +74,13 @@ class CustomLLM(BaseLLM):
Returns:
4096, a typical context window size for modern LLMs.
"""
return 4096
@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_llm_implementation():
def test_custom_llm_implementation() -> None:
"""Test that a custom LLM implementation works with create_llm."""
custom_llm = CustomLLM(response="The answer is 42")
@@ -89,7 +91,7 @@ def test_custom_llm_implementation():
# Test calling the custom LLM
response = result_llm.call(
"What is the answer to life, the universe, and everything?"
"What is the answer to life, the universe, and everything?",
)
# Verify that the response from the custom LLM was used
@@ -97,7 +99,7 @@ def test_custom_llm_implementation():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_llm_within_crew():
def test_custom_llm_within_crew() -> None:
"""Test that a custom LLM implementation works with create_llm."""
custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model")
@@ -128,8 +130,8 @@ def test_custom_llm_within_crew():
assert "Hello!" in result.raw
def test_custom_llm_message_formatting():
"""Test that the custom LLM properly formats messages"""
def test_custom_llm_message_formatting() -> None:
"""Test that the custom LLM properly formats messages."""
custom_llm = CustomLLM(response="Test response", model="test-model")
# Test with string input
@@ -148,21 +150,22 @@ def test_custom_llm_message_formatting():
class JWTAuthLLM(BaseLLM):
"""Custom LLM implementation with JWT authentication."""
def __init__(self, jwt_token: str):
def __init__(self, jwt_token: str) -> None:
super().__init__(model="test-model")
if not jwt_token or not isinstance(jwt_token, str):
raise ValueError("Invalid JWT token")
msg = "Invalid JWT token"
raise ValueError(msg)
self.jwt_token = jwt_token
self.calls = []
self.stop = []
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
"""Record the call and return a predefined response."""
self.calls.append(
{
@@ -170,7 +173,7 @@ class JWTAuthLLM(BaseLLM):
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
},
)
# In a real implementation, this would use the JWT token to authenticate
# with an external service
@@ -189,7 +192,7 @@ class JWTAuthLLM(BaseLLM):
return 8192
def test_custom_llm_with_jwt_auth():
def test_custom_llm_with_jwt_auth() -> None:
"""Test a custom LLM implementation with JWT authentication."""
jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
@@ -207,7 +210,7 @@ def test_custom_llm_with_jwt_auth():
assert response == "Response from JWT-authenticated LLM"
def test_jwt_auth_llm_validation():
def test_jwt_auth_llm_validation() -> None:
"""Test that JWT token validation works correctly."""
# Test with invalid JWT token (empty string)
with pytest.raises(ValueError, match="Invalid JWT token"):
@@ -221,12 +224,13 @@ def test_jwt_auth_llm_validation():
class TimeoutHandlingLLM(BaseLLM):
"""Custom LLM implementation with timeout handling and retry logic."""
def __init__(self, max_retries: int = 3, timeout: int = 30):
def __init__(self, max_retries: int = 3, timeout: int = 30) -> None:
"""Initialize the TimeoutHandlingLLM with retry and timeout settings.
Args:
max_retries: Maximum number of retry attempts.
timeout: Timeout in seconds for each API call.
"""
super().__init__(model="test-model")
self.max_retries = max_retries
@@ -237,11 +241,11 @@ class TimeoutHandlingLLM(BaseLLM):
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
"""Simulate API calls with timeout handling and retry logic.
Args:
@@ -255,6 +259,7 @@ class TimeoutHandlingLLM(BaseLLM):
Raises:
TimeoutError: If all retry attempts fail.
"""
# Record the initial call
self.calls.append(
@@ -264,7 +269,7 @@ class TimeoutHandlingLLM(BaseLLM):
"callbacks": callbacks,
"available_functions": available_functions,
"attempt": 0,
}
},
)
# Simulate retry logic
@@ -276,46 +281,47 @@ class TimeoutHandlingLLM(BaseLLM):
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
msg = f"LLM request failed after {self.max_retries} attempts"
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
msg,
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on first attempt
return "First attempt response"
else:
# This is a retry attempt (attempt > 0)
# Always record retry attempts
self.calls.append(
{
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
)
# Success on first attempt
return "First attempt response"
# This is a retry attempt (attempt > 0)
# Always record retry attempts
self.calls.append(
{
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
},
)
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on retry
return "Response after retry"
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
msg = f"LLM request failed after {self.max_retries} attempts"
raise TimeoutError(
msg,
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
# Success on retry
return "Response after retry"
return None
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported.
Returns:
True, indicating that this LLM supports function calling.
"""
return True
@@ -324,6 +330,7 @@ class TimeoutHandlingLLM(BaseLLM):
Returns:
True, indicating that this LLM supports stop words.
"""
return True
@@ -332,11 +339,12 @@ class TimeoutHandlingLLM(BaseLLM):
Returns:
8192, a typical context window size for modern LLMs.
"""
return 8192
def test_timeout_handling_llm():
def test_timeout_handling_llm() -> None:
"""Test a custom LLM implementation with timeout handling and retry logic."""
# Test successful first attempt
llm = TimeoutHandlingLLM()
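The signature rewrites in this file trade `typing.Union`, `Optional`, `List`, and `Dict` for PEP 604 unions and built-in generics, which is why the `from typing import ...` line at the top of the file shrinks to `from typing import Any`. A minimal sketch of the same transformation applied to a standalone function; the parameter shape is copied from the tests above, but this stub is not wired to any real LLM:

```python
from typing import Any


def call(
    messages: str | list[dict[str, str]],
    tools: list[dict] | None = None,
    callbacks: list[Any] | None = None,
    available_functions: dict[str, Any] | None = None,
) -> str | Any:
    """Return a canned response regardless of input (illustrative stub)."""
    return "stubbed response"
```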

View File

@@ -2,6 +2,7 @@
import asyncio
from datetime import datetime
from typing import NoReturn
import pytest
from pydantic import BaseModel
@@ -17,17 +18,17 @@ from crewai.utilities.events import (
from crewai.utilities.events.flow_events import FlowPlotEvent
def test_simple_sequential_flow():
def test_simple_sequential_flow() -> None:
"""Test a simple flow with two steps called sequentially."""
execution_order = []
class SimpleFlow(Flow):
@start()
def step_1(self):
def step_1(self) -> None:
execution_order.append("step_1")
@listen(step_1)
def step_2(self):
def step_2(self) -> None:
execution_order.append("step_2")
flow = SimpleFlow()
@@ -36,25 +37,25 @@ def test_simple_sequential_flow():
assert execution_order == ["step_1", "step_2"]
def test_flow_with_multiple_starts():
def test_flow_with_multiple_starts() -> None:
"""Test a flow with multiple start methods."""
execution_order = []
class MultiStartFlow(Flow):
@start()
def step_a(self):
def step_a(self) -> None:
execution_order.append("step_a")
@start()
def step_b(self):
def step_b(self) -> None:
execution_order.append("step_b")
@listen(step_a)
def step_c(self):
def step_c(self) -> None:
execution_order.append("step_c")
@listen(step_b)
def step_d(self):
def step_d(self) -> None:
execution_order.append("step_d")
flow = MultiStartFlow()
@@ -68,7 +69,7 @@ def test_flow_with_multiple_starts():
assert execution_order.index("step_d") > execution_order.index("step_b")
def test_cyclic_flow():
def test_cyclic_flow() -> None:
"""Test a cyclic flow that runs a finite number of iterations."""
execution_order = []
@@ -77,17 +78,17 @@ def test_cyclic_flow():
max_iterations = 3
@start("loop")
def step_1(self):
def step_1(self) -> None:
if self.iteration >= self.max_iterations:
return # Do not proceed further
execution_order.append(f"step_1_{self.iteration}")
@listen(step_1)
def step_2(self):
def step_2(self) -> None:
execution_order.append(f"step_2_{self.iteration}")
@router(step_2)
def step_3(self):
def step_3(self) -> str:
execution_order.append(f"step_3_{self.iteration}")
self.iteration += 1
if self.iteration < self.max_iterations:
@@ -105,21 +106,21 @@ def test_cyclic_flow():
assert execution_order == expected_order
def test_flow_with_and_condition():
def test_flow_with_and_condition() -> None:
"""Test a flow where a step waits for multiple other steps to complete."""
execution_order = []
class AndConditionFlow(Flow):
@start()
def step_1(self):
def step_1(self) -> None:
execution_order.append("step_1")
@start()
def step_2(self):
def step_2(self) -> None:
execution_order.append("step_2")
@listen(and_(step_1, step_2))
def step_3(self):
def step_3(self) -> None:
execution_order.append("step_3")
flow = AndConditionFlow()
@@ -132,21 +133,21 @@ def test_flow_with_and_condition():
assert execution_order.index("step_3") > execution_order.index("step_2")
def test_flow_with_or_condition():
def test_flow_with_or_condition() -> None:
"""Test a flow where a step is triggered when any of multiple steps complete."""
execution_order = []
class OrConditionFlow(Flow):
@start()
def step_a(self):
def step_a(self) -> None:
execution_order.append("step_a")
@start()
def step_b(self):
def step_b(self) -> None:
execution_order.append("step_b")
@listen(or_(step_a, step_b))
def step_c(self):
def step_c(self) -> None:
execution_order.append("step_c")
flow = OrConditionFlow()
@@ -155,32 +156,32 @@ def test_flow_with_or_condition():
assert "step_a" in execution_order or "step_b" in execution_order
assert "step_c" in execution_order
assert execution_order.index("step_c") > min(
execution_order.index("step_a"), execution_order.index("step_b")
execution_order.index("step_a"), execution_order.index("step_b"),
)
def test_flow_with_router():
def test_flow_with_router() -> None:
"""Test a flow that uses a router method to determine the next step."""
execution_order = []
class RouterFlow(Flow):
@start()
def start_method(self):
def start_method(self) -> None:
execution_order.append("start_method")
@router(start_method)
def router(self):
def router(self) -> str:
execution_order.append("router")
# Ensure the condition is set to True to follow the "step_if_true" path
condition = True
return "step_if_true" if condition else "step_if_false"
@listen("step_if_true")
def truthy(self):
def truthy(self) -> None:
execution_order.append("step_if_true")
@listen("step_if_false")
def falsy(self):
def falsy(self) -> None:
execution_order.append("step_if_false")
flow = RouterFlow()
@@ -189,18 +190,18 @@ def test_flow_with_router():
assert execution_order == ["start_method", "router", "step_if_true"]
def test_async_flow():
def test_async_flow() -> None:
"""Test an asynchronous flow."""
execution_order = []
class AsyncFlow(Flow):
@start()
async def step_1(self):
async def step_1(self) -> None:
execution_order.append("step_1")
await asyncio.sleep(0.1)
@listen(step_1)
async def step_2(self):
async def step_2(self) -> None:
execution_order.append("step_2")
await asyncio.sleep(0.1)
@@ -210,18 +211,19 @@ def test_async_flow():
assert execution_order == ["step_1", "step_2"]
def test_flow_with_exceptions():
def test_flow_with_exceptions() -> None:
"""Test flow behavior when exceptions occur in steps."""
execution_order = []
class ExceptionFlow(Flow):
@start()
def step_1(self):
def step_1(self) -> NoReturn:
execution_order.append("step_1")
raise ValueError("An error occurred in step_1")
msg = "An error occurred in step_1"
raise ValueError(msg)
@listen(step_1)
def step_2(self):
def step_2(self) -> None:
execution_order.append("step_2")
flow = ExceptionFlow()
@@ -233,17 +235,17 @@ def test_flow_with_exceptions():
assert execution_order == ["step_1"]
def test_flow_restart():
def test_flow_restart() -> None:
"""Test restarting a flow after it has completed."""
execution_order = []
class RestartableFlow(Flow):
@start()
def step_1(self):
def step_1(self) -> None:
execution_order.append("step_1")
@listen(step_1)
def step_2(self):
def step_2(self) -> None:
execution_order.append("step_2")
flow = RestartableFlow()
@@ -253,20 +255,20 @@ def test_flow_restart():
assert execution_order == ["step_1", "step_2", "step_1", "step_2"]
def test_flow_with_custom_state():
def test_flow_with_custom_state() -> None:
"""Test a flow that maintains and modifies internal state."""
class StateFlow(Flow):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.counter = 0
@start()
def step_1(self):
def step_1(self) -> None:
self.counter += 1
@listen(step_1)
def step_2(self):
def step_2(self) -> None:
self.counter *= 2
assert self.counter == 2
@@ -275,13 +277,13 @@ def test_flow_with_custom_state():
assert flow.counter == 2
def test_flow_uuid_unstructured():
def test_flow_uuid_unstructured() -> None:
"""Test that unstructured (dictionary) flow states automatically get a UUID that persists."""
initial_id = None
class UUIDUnstructuredFlow(Flow):
@start()
def first_method(self):
def first_method(self) -> None:
nonlocal initial_id
# Verify ID is automatically generated
assert "id" in self.state
@@ -292,7 +294,7 @@ def test_flow_uuid_unstructured():
self.state["data"] = "example"
@listen(first_method)
def second_method(self):
def second_method(self) -> None:
# Ensure the ID persists after state updates
assert "id" in self.state
assert self.state["id"] == initial_id
@@ -308,7 +310,7 @@ def test_flow_uuid_unstructured():
assert len(flow.state["id"]) == 36
def test_flow_uuid_structured():
def test_flow_uuid_structured() -> None:
"""Test that structured (Pydantic) flow states automatically get a UUID that persists."""
initial_id = None
@@ -318,7 +320,7 @@ def test_flow_uuid_structured():
class UUIDStructuredFlow(Flow[MyStructuredState]):
@start()
def first_method(self):
def first_method(self) -> None:
nonlocal initial_id
# Verify ID is automatically generated and accessible as attribute
assert hasattr(self.state, "id")
@@ -330,7 +332,7 @@ def test_flow_uuid_structured():
self.state.message = "updated"
@listen(first_method)
def second_method(self):
def second_method(self) -> None:
# Ensure the ID persists after state updates
assert hasattr(self.state, "id")
assert self.state.id == initial_id
@@ -350,42 +352,41 @@ def test_flow_uuid_structured():
assert flow.state.message == "final"
def test_router_with_multiple_conditions():
def test_router_with_multiple_conditions() -> None:
"""Test a router that triggers when any of multiple steps complete (OR condition),
and another router that triggers only after all specified steps complete (AND condition).
"""
execution_order = []
class ComplexRouterFlow(Flow):
@start()
def step_a(self):
def step_a(self) -> None:
execution_order.append("step_a")
@start()
def step_b(self):
def step_b(self) -> None:
execution_order.append("step_b")
@router(or_("step_a", "step_b"))
def router_or(self):
def router_or(self) -> str:
execution_order.append("router_or")
return "next_step_or"
@listen("next_step_or")
def handle_next_step_or_event(self):
def handle_next_step_or_event(self) -> None:
execution_order.append("handle_next_step_or_event")
@listen(handle_next_step_or_event)
def branch_2_step(self):
def branch_2_step(self) -> None:
execution_order.append("branch_2_step")
@router(and_(handle_next_step_or_event, branch_2_step))
def router_and(self):
def router_and(self) -> str:
execution_order.append("router_and")
return "final_step"
@listen("final_step")
def log_final_step(self):
def log_final_step(self) -> None:
execution_order.append("log_final_step")
flow = ComplexRouterFlow()
@@ -401,7 +402,7 @@ def test_router_with_multiple_conditions():
# Check that the AND router triggered after both relevant steps:
assert execution_order.index("router_and") > execution_order.index(
"handle_next_step_or_event"
"handle_next_step_or_event",
)
assert execution_order.index("router_and") > execution_order.index("branch_2_step")
@@ -409,23 +410,24 @@ def test_router_with_multiple_conditions():
assert execution_order.index("log_final_step") > execution_order.index("router_and")
def test_unstructured_flow_event_emission():
def test_unstructured_flow_event_emission() -> None:
"""Test that the correct events are emitted during unstructured flow
execution with all fields validated."""
execution with all fields validated.
"""
class PoemFlow(Flow):
@start()
def prepare_flower(self):
def prepare_flower(self) -> str:
self.state["flower"] = "roses"
return "foo"
@start()
def prepare_color(self):
def prepare_color(self) -> str:
self.state["color"] = "red"
return "bar"
@listen(prepare_color)
def write_first_sentence(self):
def write_first_sentence(self) -> str:
return f"{self.state['flower']} are {self.state['color']}"
@listen(write_first_sentence)
@@ -434,7 +436,7 @@ def test_unstructured_flow_event_emission():
return separator.join([first_sentence, "violets are blue"])
@listen(finish_poem)
def save_poem_to_database(self):
def save_poem_to_database(self) -> str:
# A method without args/kwargs to ensure events are sent correctly
return "roses are red\nviolets are blue"
@@ -442,15 +444,15 @@ def test_unstructured_flow_event_emission():
received_events = []
@crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event):
def handle_flow_start(source, event) -> None:
received_events.append(event)
@crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_start(source, event):
def handle_method_start(source, event) -> None:
received_events.append(event)
@crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_end(source, event):
def handle_flow_end(source, event) -> None:
received_events.append(event)
flow.kickoff(inputs={"separator": ", "})
@@ -473,7 +475,6 @@ def test_unstructured_flow_event_emission():
assert received_events[2].method_name == "prepare_color"
assert received_events[2].params == {}
print("received_events[2]", received_events[2])
assert "flower" in received_events[2].state
assert received_events[3].method_name == "write_first_sentence"
@@ -497,9 +498,10 @@ def test_unstructured_flow_event_emission():
assert isinstance(received_events[6].timestamp, datetime)
def test_structured_flow_event_emission():
def test_structured_flow_event_emission() -> None:
"""Test that the correct events are emitted during structured flow
execution with all fields validated."""
execution with all fields validated.
"""
class OnboardingState(BaseModel):
name: str = ""
@@ -507,11 +509,11 @@ def test_structured_flow_event_emission():
class OnboardingFlow(Flow[OnboardingState]):
@start()
def user_signs_up(self):
def user_signs_up(self) -> None:
self.state.sent = False
@listen(user_signs_up)
def send_welcome_message(self):
def send_welcome_message(self) -> str:
self.state.sent = True
return f"Welcome, {self.state.name}!"
@@ -521,19 +523,19 @@ def test_structured_flow_event_emission():
received_events = []
@crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event):
def handle_flow_start(source, event) -> None:
received_events.append(event)
@crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_start(source, event):
def handle_method_start(source, event) -> None:
received_events.append(event)
@crewai_event_bus.on(MethodExecutionFinishedEvent)
def handle_method_end(source, event):
def handle_method_end(source, event) -> None:
received_events.append(event)
@crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_end(source, event):
def handle_flow_end(source, event) -> None:
received_events.append(event)
flow.kickoff(inputs={"name": "Anakin"})
@@ -552,11 +554,11 @@ def test_structured_flow_event_emission():
assert isinstance(received_events[3], MethodExecutionStartedEvent)
assert received_events[3].method_name == "send_welcome_message"
assert received_events[3].params == {}
assert getattr(received_events[3].state, "sent") is False
assert received_events[3].state.sent is False
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
assert received_events[4].method_name == "send_welcome_message"
assert getattr(received_events[4].state, "sent") is True
assert received_events[4].state.sent is True
assert received_events[4].result == "Welcome, Anakin!"
assert isinstance(received_events[5], FlowFinishedEvent)
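The `getattr(..., "sent")` rewrites just above replace a constant attribute name passed to `getattr` with plain attribute access, which reads the same field and type-checks better. A tiny sketch of the pattern, assuming the rule involved is flake8-bugbear's B009 (constant name passed to `getattr`); `State` is a stand-in class, not the flow state model:

```python
class State:
    sent: bool = False


state = State()
assert getattr(state, "sent") is False  # before: attribute name as a literal
assert state.sent is False              # after: direct attribute access
```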
@@ -565,41 +567,42 @@ def test_structured_flow_event_emission():
assert isinstance(received_events[5].timestamp, datetime)
def test_stateless_flow_event_emission():
def test_stateless_flow_event_emission() -> None:
"""Test that the correct events are emitted stateless during flow execution
with all fields validated."""
with all fields validated.
"""
class StatelessFlow(Flow):
@start()
def init(self):
def init(self) -> None:
pass
@listen(init)
def process(self):
def process(self) -> str:
return "Deeds will not be less valiant because they are unpraised."
event_log = []
def handle_event(_, event):
def handle_event(_, event) -> None:
event_log.append(event)
flow = StatelessFlow()
received_events = []
@crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event):
def handle_flow_start(source, event) -> None:
received_events.append(event)
@crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_start(source, event):
def handle_method_start(source, event) -> None:
received_events.append(event)
@crewai_event_bus.on(MethodExecutionFinishedEvent)
def handle_method_end(source, event):
def handle_method_end(source, event) -> None:
received_events.append(event)
@crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_end(source, event):
def handle_flow_end(source, event) -> None:
received_events.append(event)
flow.kickoff()
@@ -630,14 +633,14 @@ def test_stateless_flow_event_emission():
assert isinstance(received_events[5].timestamp, datetime)
def test_flow_plotting():
def test_flow_plotting() -> None:
class StatelessFlow(Flow):
@start()
def init(self):
def init(self) -> str:
return "Initializing flow..."
@listen(init)
def process(self):
def process(self) -> str:
return "Deeds will not be less valiant because they are unpraised."
flow = StatelessFlow()
@@ -645,7 +648,7 @@ def test_flow_plotting():
received_events = []
@crewai_event_bus.on(FlowPlotEvent)
def handle_flow_plot(source, event):
def handle_flow_plot(source, event) -> None:
received_events.append(event)
flow.plot("test_flow")
@@ -656,59 +659,59 @@ def test_flow_plotting():
assert isinstance(received_events[0].timestamp, datetime)
def test_multiple_routers_from_same_trigger():
def test_multiple_routers_from_same_trigger() -> None:
"""Test that multiple routers triggered by the same method all activate their listeners."""
execution_order = []
class MultiRouterFlow(Flow):
def __init__(self):
def __init__(self) -> None:
super().__init__()
# Set diagnosed conditions to trigger all routers
self.state["diagnosed_conditions"] = "DHA" # Contains D, H, and A
@start()
def scan_medical(self):
def scan_medical(self) -> str:
execution_order.append("scan_medical")
return "scan_complete"
@router(scan_medical)
def diagnose_conditions(self):
def diagnose_conditions(self) -> str:
execution_order.append("diagnose_conditions")
return "diagnosis_complete"
@router(diagnose_conditions)
def diabetes_router(self):
def diabetes_router(self) -> str | None:
execution_order.append("diabetes_router")
if "D" in self.state["diagnosed_conditions"]:
return "diabetes"
return None
@listen("diabetes")
def diabetes_analysis(self):
def diabetes_analysis(self) -> str:
execution_order.append("diabetes_analysis")
return "diabetes_analysis_complete"
@router(diagnose_conditions)
def hypertension_router(self):
def hypertension_router(self) -> str | None:
execution_order.append("hypertension_router")
if "H" in self.state["diagnosed_conditions"]:
return "hypertension"
return None
@listen("hypertension")
def hypertension_analysis(self):
def hypertension_analysis(self) -> str:
execution_order.append("hypertension_analysis")
return "hypertension_analysis_complete"
@router(diagnose_conditions)
def anemia_router(self):
def anemia_router(self) -> str | None:
execution_order.append("anemia_router")
if "A" in self.state["diagnosed_conditions"]:
return "anemia"
return None
@listen("anemia")
def anemia_analysis(self):
def anemia_analysis(self) -> str:
execution_order.append("anemia_analysis")
return "anemia_analysis_complete"
@@ -731,27 +734,27 @@ def test_multiple_routers_from_same_trigger():
# Verify execution order constraints
assert execution_order.index("diagnose_conditions") > execution_order.index(
"scan_medical"
"scan_medical",
)
# All routers should execute after diagnose_conditions
assert execution_order.index("diabetes_router") > execution_order.index(
"diagnose_conditions"
"diagnose_conditions",
)
assert execution_order.index("hypertension_router") > execution_order.index(
"diagnose_conditions"
"diagnose_conditions",
)
assert execution_order.index("anemia_router") > execution_order.index(
"diagnose_conditions"
"diagnose_conditions",
)
# All analyses should execute after their respective routers
assert execution_order.index("diabetes_analysis") > execution_order.index(
"diabetes_router"
"diabetes_router",
)
assert execution_order.index("hypertension_analysis") > execution_order.index(
"hypertension_router"
"hypertension_router",
)
assert execution_order.index("anemia_analysis") > execution_order.index(
"anemia_router"
"anemia_router",
)

View File

@@ -1,15 +1,15 @@
"""Test that all public API classes are properly importable."""
def test_task_output_import():
def test_task_output_import() -> None:
"""Test that TaskOutput can be imported from crewai."""
from crewai import TaskOutput
assert TaskOutput is not None
def test_crew_output_import():
def test_crew_output_import() -> None:
"""Test that CrewOutput can be imported from crewai."""
from crewai import CrewOutput
assert CrewOutput is not None

View File

@@ -1,7 +1,6 @@
"""Test Knowledge creation and querying functionality."""
from pathlib import Path
from typing import List, Union
from unittest.mock import patch
import pytest
@@ -25,23 +24,23 @@ def mock_vector_db():
{
"context": "Brandon's favorite color is blue and he likes Mexican food.",
"score": 0.9,
}
},
]
instance.reset.return_value = None
yield instance
@pytest.fixture(autouse=True)
def reset_knowledge_storage(mock_vector_db):
def reset_knowledge_storage(mock_vector_db) -> None:
"""Fixture to reset knowledge storage before each test."""
yield
return
def test_single_short_string(mock_vector_db):
def test_single_short_string(mock_vector_db) -> None:
# Create a knowledge base with a single short string
content = "Brandon's favorite color is blue and he likes Mexican food."
string_source = StringKnowledgeSource(
content=content, metadata={"preference": "personal"}
content=content, metadata={"preference": "personal"},
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
@@ -56,7 +55,7 @@ def test_single_short_string(mock_vector_db):
# @pytest.mark.vcr(filter_headers=["authorization"])
def test_single_2k_character_string(mock_vector_db):
def test_single_2k_character_string(mock_vector_db) -> None:
# Create a 2k character string with various facts about Brandon
content = (
"Brandon is a software engineer who lives in San Francisco. "
@@ -81,7 +80,7 @@ def test_single_2k_character_string(mock_vector_db):
"He is also a fan of the Golden State Warriors and enjoys watching their games. "
)
string_source = StringKnowledgeSource(
content=content, metadata={"preference": "personal"}
content=content, metadata={"preference": "personal"},
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
@@ -95,7 +94,7 @@ def test_single_2k_character_string(mock_vector_db):
mock_vector_db.query.assert_called_once()
def test_multiple_short_strings(mock_vector_db):
def test_multiple_short_strings(mock_vector_db) -> None:
# Create multiple short string sources
contents = [
"Brandon loves hiking.",
@@ -109,7 +108,7 @@ def test_multiple_short_strings(mock_vector_db):
# Mock the vector db query response
mock_vector_db.query.return_value = [
{"context": "Brandon has a dog named Max.", "score": 0.9}
{"context": "Brandon has a dog named Max.", "score": 0.9},
]
mock_vector_db.sources = string_sources
@@ -124,7 +123,7 @@ def test_multiple_short_strings(mock_vector_db):
mock_vector_db.query.assert_called_once()
def test_multiple_2k_character_strings(mock_vector_db):
def test_multiple_2k_character_strings(mock_vector_db) -> None:
# Create multiple 2k character strings with various facts about Brandon
contents = [
(
@@ -194,7 +193,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
mock_vector_db.query.assert_called_once()
def test_single_short_file(mock_vector_db, tmpdir):
def test_single_short_file(mock_vector_db, tmpdir) -> None:
# Create a single short text file
content = "Brandon's favorite sport is basketball."
file_path = Path(tmpdir.join("short_file.txt"))
@@ -202,7 +201,7 @@ def test_single_short_file(mock_vector_db, tmpdir):
f.write(content)
file_source = TextFileKnowledgeSource(
file_paths=[file_path], metadata={"preference": "personal"}
file_paths=[file_path], metadata={"preference": "personal"},
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
@@ -215,7 +214,7 @@ def test_single_short_file(mock_vector_db, tmpdir):
mock_vector_db.query.assert_called_once()
def test_single_2k_character_file(mock_vector_db, tmpdir):
def test_single_2k_character_file(mock_vector_db, tmpdir) -> None:
# Create a single 2k character text file with various facts about Brandon
content = (
"Brandon is a software engineer who lives in San Francisco. "
@@ -244,7 +243,7 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
f.write(content)
file_source = TextFileKnowledgeSource(
file_paths=[file_path], metadata={"preference": "personal"}
file_paths=[file_path], metadata={"preference": "personal"},
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
@@ -257,7 +256,7 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
mock_vector_db.query.assert_called_once()
def test_multiple_short_files(mock_vector_db, tmpdir):
def test_multiple_short_files(mock_vector_db, tmpdir) -> None:
# Create multiple short text files
contents = [
{
@@ -286,7 +285,7 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
]
mock_vector_db.sources = file_sources
mock_vector_db.query.return_value = [
{"context": "Brandon lives in New York.", "score": 0.9}
{"context": "Brandon lives in New York.", "score": 0.9},
]
# Perform a query
query = "What city does he reside in?"
@@ -296,7 +295,7 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
mock_vector_db.query.assert_called_once()
def test_multiple_2k_character_files(mock_vector_db, tmpdir):
def test_multiple_2k_character_files(mock_vector_db, tmpdir) -> None:
# Create multiple 2k character text files with various facts about Brandon
contents = [
(
@@ -362,7 +361,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
{
"context": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
"score": 0.9,
}
},
]
# Perform a query
query = "What is Brandon's favorite book?"
@@ -377,7 +376,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_hybrid_string_and_files(mock_vector_db, tmpdir):
def test_hybrid_string_and_files(mock_vector_db, tmpdir) -> None:
# Create string sources
string_contents = [
"Brandon is learning French.",
@@ -418,7 +417,7 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
mock_vector_db.query.assert_called_once()
def test_pdf_knowledge_source(mock_vector_db):
def test_pdf_knowledge_source(mock_vector_db) -> None:
# Get the directory of the current file
current_dir = Path(__file__).parent
# Construct the path to the PDF file
@@ -426,11 +425,11 @@ def test_pdf_knowledge_source(mock_vector_db):
# Create a PDFKnowledgeSource
pdf_source = PDFKnowledgeSource(
file_paths=[pdf_path], metadata={"preference": "personal"}
file_paths=[pdf_path], metadata={"preference": "personal"},
)
mock_vector_db.sources = [pdf_source]
mock_vector_db.query.return_value = [
{"context": "crewai create crew latest-ai-development", "score": 0.9}
{"context": "crewai create crew latest-ai-development", "score": 0.9},
]
# Perform a query
@@ -446,9 +445,8 @@ def test_pdf_knowledge_source(mock_vector_db):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_csv_knowledge_source(mock_vector_db, tmpdir):
def test_csv_knowledge_source(mock_vector_db, tmpdir) -> None:
"""Test CSVKnowledgeSource with a simple CSV file."""
# Create a CSV file with sample data
csv_content = [
["Name", "Age", "City"],
@@ -463,11 +461,11 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
# Create a CSVKnowledgeSource
csv_source = CSVKnowledgeSource(
file_paths=[csv_path], metadata={"preference": "personal"}
file_paths=[csv_path], metadata={"preference": "personal"},
)
mock_vector_db.sources = [csv_source]
mock_vector_db.query.return_value = [
{"context": "Brandon is 30 years old.", "score": 0.9}
{"context": "Brandon is 30 years old.", "score": 0.9},
]
# Perform a query
@@ -479,16 +477,15 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
mock_vector_db.query.assert_called_once()
def test_json_knowledge_source(mock_vector_db, tmpdir):
def test_json_knowledge_source(mock_vector_db, tmpdir) -> None:
"""Test JSONKnowledgeSource with a simple JSON file."""
# Create a JSON file with sample data
json_data = {
"people": [
{"name": "Brandon", "age": 30, "city": "New York"},
{"name": "Alice", "age": 25, "city": "Los Angeles"},
{"name": "Bob", "age": 35, "city": "Chicago"},
]
],
}
json_path = Path(tmpdir.join("data.json"))
with open(json_path, "w", encoding="utf-8") as f:
@@ -498,11 +495,11 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
# Create a JSONKnowledgeSource
json_source = JSONKnowledgeSource(
file_paths=[json_path], metadata={"preference": "personal"}
file_paths=[json_path], metadata={"preference": "personal"},
)
mock_vector_db.sources = [json_source]
mock_vector_db.query.return_value = [
{"context": "Alice lives in Los Angeles.", "score": 0.9}
{"context": "Alice lives in Los Angeles.", "score": 0.9},
]
# Perform a query
@@ -514,9 +511,8 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
mock_vector_db.query.assert_called_once()
def test_excel_knowledge_source(mock_vector_db, tmpdir):
def test_excel_knowledge_source(mock_vector_db, tmpdir) -> None:
"""Test ExcelKnowledgeSource with a simple Excel file."""
# Create an Excel file with sample data
import pandas as pd
@@ -531,11 +527,11 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
# Create an ExcelKnowledgeSource
excel_source = ExcelKnowledgeSource(
file_paths=[excel_path], metadata={"preference": "personal"}
file_paths=[excel_path], metadata={"preference": "personal"},
)
mock_vector_db.sources = [excel_source]
mock_vector_db.query.return_value = [
{"context": "Brandon is 30 years old.", "score": 0.9}
{"context": "Brandon is 30 years old.", "score": 0.9},
]
# Perform a query
@@ -548,7 +544,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
@pytest.mark.vcr
def test_docling_source(mock_vector_db):
def test_docling_source(mock_vector_db) -> None:
docling_source = CrewDoclingSource(
file_paths=[
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
@@ -559,7 +555,7 @@ def test_docling_source(mock_vector_db):
{
"context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
"score": 0.9,
}
},
]
# Perform a query
query = "What is reward hacking?"
@@ -569,8 +565,8 @@ def test_docling_source(mock_vector_db):
@pytest.mark.vcr
def test_multiple_docling_sources():
urls: List[Union[Path, str]] = [
def test_multiple_docling_sources() -> None:
urls: list[Path | str] = [
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
"https://lilianweng.github.io/posts/2024-07-07-hallucination/",
]
@@ -580,7 +576,7 @@ def test_multiple_docling_sources():
assert docling_source.content is not None
def test_file_path_validation():
def test_file_path_validation() -> None:
"""Test file path validation for knowledge sources."""
current_dir = Path(__file__).parent
pdf_path = current_dir / "crewai_quickstart.pdf"

View File

@@ -2,7 +2,6 @@ import os
from time import sleep
from unittest.mock import MagicMock, patch
import litellm
import pytest
from pydantic import BaseModel
@@ -17,29 +16,25 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_callback_replacement():
def test_llm_callback_replacement() -> None:
llm1 = LLM(model="gpt-4o-mini")
llm2 = LLM(model="gpt-4o-mini")
calc_handler_1 = TokenCalcHandler(token_cost_process=TokenProcess())
calc_handler_2 = TokenCalcHandler(token_cost_process=TokenProcess())
result1 = llm1.call(
llm1.call(
messages=[{"role": "user", "content": "Hello, world!"}],
callbacks=[calc_handler_1],
)
print("result1:", result1)
usage_metrics_1 = calc_handler_1.token_cost_process.get_summary()
print("usage_metrics_1:", usage_metrics_1)
result2 = llm2.call(
llm2.call(
messages=[{"role": "user", "content": "Hello, world from another agent!"}],
callbacks=[calc_handler_2],
)
sleep(5)
print("result2:", result2)
usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
print("usage_metrics_2:", usage_metrics_2)
# The first handler should not have been updated
assert usage_metrics_1.successful_requests == 1
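The hunk above strips the debug `print` calls and the `result1`/`result2` assignments whose values were only ever printed; the calls are kept purely for their side effects on the token handlers, which is what the assertions check. A compact sketch of that cleanup, assuming the driving checks are the unused-variable and print rules (F841/T201 in ruff); `Handler` and `run_once` are made-up names:

```python
class Handler:
    def __init__(self) -> None:
        self.successful_requests = 0


def run_once(handler: Handler) -> str:
    handler.successful_requests += 1
    return "ok"


def exercise_before(handler: Handler) -> None:
    result = run_once(handler)   # assigned only to be printed
    print("result:", result)     # debug output flagged by the linter
    assert handler.successful_requests == 1


def exercise_after(handler: Handler) -> None:
    run_once(handler)            # call kept for its side effect only
    assert handler.successful_requests == 1
```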
@@ -48,7 +43,7 @@ def test_llm_callback_replacement():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_string_input():
def test_llm_call_with_string_input() -> None:
llm = LLM(model="gpt-4o-mini")
# Test the call method with a string input
@@ -58,7 +53,7 @@ def test_llm_call_with_string_input():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_string_input_and_callbacks():
def test_llm_call_with_string_input_and_callbacks() -> None:
llm = LLM(model="gpt-4o-mini")
calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
@@ -75,7 +70,7 @@ def test_llm_call_with_string_input_and_callbacks():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_message_list():
def test_llm_call_with_message_list() -> None:
llm = LLM(model="gpt-4o-mini")
messages = [{"role": "user", "content": "What is the capital of France?"}]
@@ -86,7 +81,7 @@ def test_llm_call_with_message_list():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_tool_and_string_input():
def test_llm_call_with_tool_and_string_input() -> None:
llm = LLM(model="gpt-4o-mini")
def get_current_year() -> str:
@@ -124,7 +119,7 @@ def test_llm_call_with_tool_and_string_input():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_tool_and_message_list():
def test_llm_call_with_tool_and_message_list() -> None:
llm = LLM(model="gpt-4o-mini")
def square_number(number: int) -> int:
@@ -140,7 +135,7 @@ def test_llm_call_with_tool_and_message_list():
"parameters": {
"type": "object",
"properties": {
"number": {"type": "integer", "description": "The number to square"}
"number": {"type": "integer", "description": "The number to square"},
},
"required": ["number"],
},
@@ -164,7 +159,7 @@ def test_llm_call_with_tool_and_message_list():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_passes_additional_params():
def test_llm_passes_additional_params() -> None:
llm = LLM(
model="gpt-4o-mini",
vertex_credentials="test_credentials",
@@ -210,35 +205,35 @@ def test_llm_passes_additional_params():
assert result == "Test response"
def test_get_custom_llm_provider_openrouter():
def test_get_custom_llm_provider_openrouter() -> None:
llm = LLM(model="openrouter/deepseek/deepseek-chat")
assert llm._get_custom_llm_provider() == "openrouter"
def test_get_custom_llm_provider_gemini():
def test_get_custom_llm_provider_gemini() -> None:
llm = LLM(model="gemini/gemini-1.5-pro")
assert llm._get_custom_llm_provider() == "gemini"
def test_get_custom_llm_provider_openai():
def test_get_custom_llm_provider_openai() -> None:
llm = LLM(model="gpt-4")
assert llm._get_custom_llm_provider() == None
assert llm._get_custom_llm_provider() is None
def test_validate_call_params_supported():
def test_validate_call_params_supported() -> None:
class DummyResponse(BaseModel):
a: int
# Patch supports_response_schema to simulate a supported model.
with patch("crewai.llm.supports_response_schema", return_value=True):
llm = LLM(
model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse
model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse,
)
# Should not raise any error.
llm._validate_call_params()
def test_validate_call_params_not_supported():
def test_validate_call_params_not_supported() -> None:
class DummyResponse(BaseModel):
a: int
@@ -250,7 +245,7 @@ def test_validate_call_params_not_supported():
assert "does not support response_format" in str(excinfo.value)
def test_validate_call_params_no_response_format():
def test_validate_call_params_no_response_format() -> None:
# When no response_format is provided, no validation error should occur.
llm = LLM(model="gemini/gemini-1.5-pro", response_format=None)
llm._validate_call_params()
@@ -267,7 +262,7 @@ def test_validate_call_params_no_response_format():
"gemini/gemini-2.5-pro-exp-03-25",
],
)
def test_gemini_models(model):
def test_gemini_models(model) -> None:
llm = LLM(model=model)
result = llm.call("What is the capital of France?")
assert isinstance(result, str)
@@ -284,7 +279,7 @@ def test_gemini_models(model):
"gemini/gemma-3-27b-it",
],
)
def test_gemma3(model):
def test_gemma3(model) -> None:
llm = LLM(model=model)
result = llm.call("What is the capital of France?")
assert isinstance(result, str)
@@ -293,9 +288,9 @@ def test_gemma3(model):
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.parametrize(
"model", ["gpt-4.1", "gpt-4.1-mini-2025-04-14", "gpt-4.1-nano-2025-04-14"]
"model", ["gpt-4.1", "gpt-4.1-mini-2025-04-14", "gpt-4.1-nano-2025-04-14"],
)
def test_gpt_4_1(model):
def test_gpt_4_1(model) -> None:
llm = LLM(model=model)
result = llm.call("What is the capital of France?")
assert isinstance(result, str)
@@ -303,7 +298,7 @@ def test_gpt_4_1(model):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_o3_mini_reasoning_effort_high():
def test_o3_mini_reasoning_effort_high() -> None:
llm = LLM(
model="o3-mini",
reasoning_effort="high",
@@ -314,7 +309,7 @@ def test_o3_mini_reasoning_effort_high():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_o3_mini_reasoning_effort_low():
def test_o3_mini_reasoning_effort_low() -> None:
llm = LLM(
model="o3-mini",
reasoning_effort="low",
@@ -325,7 +320,7 @@ def test_o3_mini_reasoning_effort_low():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_o3_mini_reasoning_effort_medium():
def test_o3_mini_reasoning_effort_medium() -> None:
llm = LLM(
model="o3-mini",
reasoning_effort="medium",
@@ -335,21 +330,20 @@ def test_o3_mini_reasoning_effort_medium():
assert "Paris" in result
def test_context_window_validation():
def test_context_window_validation() -> None:
"""Test that context window validation works correctly."""
# Test valid window size
llm = LLM(model="o3-mini")
assert llm.get_context_window_size() == int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
# Test invalid window size
with pytest.raises(ValueError) as excinfo:
with patch.dict(
"crewai.llm.LLM_CONTEXT_WINDOW_SIZES",
{"test-model": 500}, # Below minimum
clear=True,
):
llm = LLM(model="test-model")
llm.get_context_window_size()
with pytest.raises(ValueError) as excinfo, patch.dict(
"crewai.llm.LLM_CONTEXT_WINDOW_SIZES",
{"test-model": 500}, # Below minimum
clear=True,
):
llm = LLM(model="test-model")
llm.get_context_window_size()
assert "must be between 1024 and 2097152" in str(excinfo.value)
@@ -366,14 +360,14 @@ def get_weather_tool_schema():
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
}
},
},
"required": ["location"],
},
},
}
def test_context_window_exceeded_error_handling():
def test_context_window_exceeded_error_handling() -> None:
"""Test that litellm.ContextWindowExceededError is converted to LLMContextLengthExceededException."""
from litellm.exceptions import ContextWindowExceededError
@@ -388,7 +382,7 @@ def test_context_window_exceeded_error_handling():
mock_completion.side_effect = ContextWindowExceededError(
"This model's maximum context length is 8192 tokens. However, your messages resulted in 10000 tokens.",
model="gpt-4",
llm_provider="openai"
llm_provider="openai",
)
with pytest.raises(LLMContextLengthExceededException) as excinfo:
@@ -403,7 +397,7 @@ def test_context_window_exceeded_error_handling():
mock_completion.side_effect = ContextWindowExceededError(
"This model's maximum context length is 8192 tokens. However, your messages resulted in 10000 tokens.",
model="gpt-4",
llm_provider="openai"
llm_provider="openai",
)
with pytest.raises(LLMContextLengthExceededException) as excinfo:
@@ -432,7 +426,7 @@ def user_message():
return {"role": "user", "content": "test"}
def test_anthropic_message_formatting_edge_cases(anthropic_llm):
def test_anthropic_message_formatting_edge_cases(anthropic_llm) -> None:
"""Test edge cases for Anthropic message formatting."""
# Test None messages
with pytest.raises(TypeError, match="Messages cannot be None"):
@@ -449,7 +443,7 @@ def test_anthropic_message_formatting_edge_cases(anthropic_llm):
anthropic_llm._format_messages_for_provider([{"invalid": "message"}])
def test_anthropic_model_detection():
def test_anthropic_model_detection() -> None:
"""Test Anthropic model detection with various formats."""
models = [
("anthropic/claude-3", True),
@@ -465,7 +459,7 @@ def test_anthropic_model_detection():
assert llm.is_anthropic == expected, f"Failed for model: {model}"
def test_anthropic_message_formatting(anthropic_llm, system_message, user_message):
def test_anthropic_message_formatting(anthropic_llm, system_message, user_message) -> None:
"""Test Anthropic message formatting with fixtures."""
# Test when first message is system
formatted = anthropic_llm._format_messages_for_provider([system_message])
@@ -492,7 +486,7 @@ def test_anthropic_message_formatting(anthropic_llm, system_message, user_messag
assert formatted[0] == system_message
def test_deepseek_r1_with_open_router():
def test_deepseek_r1_with_open_router() -> None:
if not os.getenv("OPEN_ROUTER_API_KEY"):
pytest.skip("OPEN_ROUTER_API_KEY not set; skipping test.")
@@ -512,7 +506,7 @@ def assert_event_count(
expected_stream_chunk: int = 0,
expected_completed_llm_call: int = 0,
expected_final_chunk_result: str = "",
):
) -> None:
event_count = {
"completed_tool_call": 0,
"stream_chunk": 0,
@@ -553,7 +547,7 @@ def mock_emit() -> MagicMock:
@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_streaming_tool_calls(get_weather_tool_schema, mock_emit):
def test_handle_streaming_tool_calls(get_weather_tool_schema, mock_emit) -> None:
llm = LLM(model="openai/gpt-4o", stream=True)
response = llm.call(
messages=[
@@ -561,7 +555,7 @@ def test_handle_streaming_tool_calls(get_weather_tool_schema, mock_emit):
],
tools=[get_weather_tool_schema],
available_functions={
"get_weather": lambda location: f"The weather in {location} is sunny"
"get_weather": lambda location: f"The weather in {location} is sunny",
},
)
assert response == "The weather in New York, NY is sunny"
@@ -580,8 +574,8 @@ def test_handle_streaming_tool_calls(get_weather_tool_schema, mock_emit):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_streaming_tool_calls_no_available_functions(
get_weather_tool_schema, mock_emit
):
get_weather_tool_schema, mock_emit,
) -> None:
llm = LLM(model="openai/gpt-4o", stream=True)
response = llm.call(
messages=[
@@ -600,7 +594,7 @@ def test_handle_streaming_tool_calls_no_available_functions(
@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_streaming_tool_calls_no_tools(mock_emit):
def test_handle_streaming_tool_calls_no_tools(mock_emit) -> None:
llm = LLM(model="openai/gpt-4o", stream=True)
response = llm.call(
messages=[

View File

@@ -13,8 +13,7 @@ from crewai.task import Task
@pytest.fixture
def mock_mem0_memory():
mock_memory = MagicMock(spec=Memory)
return mock_memory
return MagicMock(spec=Memory)
@pytest.fixture
@@ -29,8 +28,7 @@ def patch_configure_mem0(mock_mem0_memory):
@pytest.fixture
def external_memory_with_mocked_config(patch_configure_mem0):
embedder_config = {"provider": "mem0"}
external_memory = ExternalMemory(embedder_config=embedder_config)
return external_memory
return ExternalMemory(embedder_config=embedder_config)
@pytest.fixture
@@ -49,7 +47,7 @@ def crew_with_external_memory(external_memory_with_mocked_config, patch_configur
agent=agent,
)
crew = Crew(
return Crew(
agents=[agent],
tasks=[task],
verbose=True,
@@ -58,12 +56,11 @@ def crew_with_external_memory(external_memory_with_mocked_config, patch_configur
external_memory=external_memory_with_mocked_config,
)
return crew
@pytest.fixture
def crew_with_external_memory_without_memory_flag(
external_memory_with_mocked_config, patch_configure_mem0
external_memory_with_mocked_config, patch_configure_mem0,
):
agent = Agent(
role="Researcher",
@@ -79,7 +76,7 @@ def crew_with_external_memory_without_memory_flag(
agent=agent,
)
crew = Crew(
return Crew(
agents=[agent],
tasks=[task],
verbose=True,
@@ -87,17 +84,16 @@ def crew_with_external_memory_without_memory_flag(
external_memory=external_memory_with_mocked_config,
)
return crew
def test_external_memory_initialization(external_memory_with_mocked_config):
def test_external_memory_initialization(external_memory_with_mocked_config) -> None:
assert external_memory_with_mocked_config is not None
assert isinstance(external_memory_with_mocked_config, ExternalMemory)
def test_external_memory_save(external_memory_with_mocked_config):
def test_external_memory_save(external_memory_with_mocked_config) -> None:
memory_item = ExternalMemoryItem(
value="test value", metadata={"task": "test_task"}, agent="test_agent"
value="test value", metadata={"task": "test_task"}, agent="test_agent",
)
with patch.object(ExternalMemory, "save") as mock_save:
@@ -114,51 +110,51 @@ def test_external_memory_save(external_memory_with_mocked_config):
)
def test_external_memory_reset(external_memory_with_mocked_config):
def test_external_memory_reset(external_memory_with_mocked_config) -> None:
with patch(
"crewai.memory.external.external_memory.ExternalMemory.reset"
"crewai.memory.external.external_memory.ExternalMemory.reset",
) as mock_reset:
external_memory_with_mocked_config.reset()
mock_reset.assert_called_once()
def test_external_memory_supported_storages():
def test_external_memory_supported_storages() -> None:
supported_storages = ExternalMemory.external_supported_storages()
assert "mem0" in supported_storages
assert callable(supported_storages["mem0"])
def test_external_memory_create_storage_invalid_provider():
def test_external_memory_create_storage_invalid_provider() -> None:
embedder_config = {"provider": "invalid_provider", "config": {}}
with pytest.raises(ValueError, match="Provider invalid_provider not supported"):
ExternalMemory.create_storage(None, embedder_config)
def test_external_memory_create_storage_missing_provider():
def test_external_memory_create_storage_missing_provider() -> None:
embedder_config = {"config": {}}
with pytest.raises(
ValueError, match="embedder_config must include a 'provider' key"
ValueError, match="embedder_config must include a 'provider' key",
):
ExternalMemory.create_storage(None, embedder_config)
def test_external_memory_create_storage_missing_config():
def test_external_memory_create_storage_missing_config() -> None:
with pytest.raises(ValueError, match="embedder_config is required"):
ExternalMemory.create_storage(None, None)
def test_crew_with_external_memory_initialization(crew_with_external_memory):
def test_crew_with_external_memory_initialization(crew_with_external_memory) -> None:
assert crew_with_external_memory._external_memory is not None
assert isinstance(crew_with_external_memory._external_memory, ExternalMemory)
assert crew_with_external_memory._external_memory.crew == crew_with_external_memory
@pytest.mark.parametrize("mem_type", ["external", "all"])
def test_crew_external_memory_reset(mem_type, crew_with_external_memory):
def test_crew_external_memory_reset(mem_type, crew_with_external_memory) -> None:
with patch(
"crewai.memory.external.external_memory.ExternalMemory.reset"
"crewai.memory.external.external_memory.ExternalMemory.reset",
) as mock_reset:
crew_with_external_memory.reset_memories(mem_type)
mock_reset.assert_called_once()
@@ -167,10 +163,10 @@ def test_crew_external_memory_reset(mem_type, crew_with_external_memory):
@pytest.mark.parametrize("mem_method", ["search", "save"])
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_external_memory_save_with_memory_flag(
mem_method, crew_with_external_memory
):
mem_method, crew_with_external_memory,
) -> None:
with patch(
f"crewai.memory.external.external_memory.ExternalMemory.{mem_method}"
f"crewai.memory.external.external_memory.ExternalMemory.{mem_method}",
) as mock_method:
crew_with_external_memory.kickoff()
assert mock_method.call_count > 0
@@ -179,27 +175,27 @@ def test_crew_external_memory_save_with_memory_flag(
@pytest.mark.parametrize("mem_method", ["search", "save"])
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_external_memory_save_using_crew_without_memory_flag(
mem_method, crew_with_external_memory_without_memory_flag
):
mem_method, crew_with_external_memory_without_memory_flag,
) -> None:
with patch(
f"crewai.memory.external.external_memory.ExternalMemory.{mem_method}"
f"crewai.memory.external.external_memory.ExternalMemory.{mem_method}",
) as mock_method:
crew_with_external_memory_without_memory_flag.kickoff()
assert mock_method.call_count > 0
def test_external_memory_custom_storage(crew_with_external_memory):
def test_external_memory_custom_storage(crew_with_external_memory) -> None:
class CustomStorage(Storage):
def __init__(self):
def __init__(self) -> None:
self.memories = []
def save(self, value, metadata=None, agent=None):
def save(self, value, metadata=None, agent=None) -> None:
self.memories.append({"value": value, "metadata": metadata, "agent": agent})
def search(self, query, limit=10, score_threshold=0.5):
return self.memories
def reset(self):
def reset(self) -> None:
self.memories = []
custom_storage = CustomStorage()
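
Note on the pattern above: most of this file's diff is the linter adding explicit -> None return annotations to test functions and helpers that return nothing. A minimal, self-contained sketch of the same idea follows; the function names (add, test_add) are invented for illustration, and attributing the change to ruff's ANN-family annotation checks is an assumption, not something this commit states.

def add(a: int, b: int) -> int:
    return a + b

def test_add() -> None:  # test functions return nothing, so the annotation is always -> None
    assert add(2, 3) == 5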

View File

@@ -6,11 +6,11 @@ from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
@pytest.fixture
def long_term_memory():
"""Fixture to create a LongTermMemory instance"""
"""Fixture to create a LongTermMemory instance."""
return LongTermMemory()
def test_save_and_search(long_term_memory):
def test_save_and_search(long_term_memory) -> None:
memory = LongTermMemoryItem(
agent="test_agent",
task="test_task",

View File

@@ -11,7 +11,7 @@ from crewai.task import Task
@pytest.fixture
def short_term_memory():
"""Fixture to create a ShortTermMemory instance"""
"""Fixture to create a ShortTermMemory instance."""
agent = Agent(
role="Researcher",
goal="Search relevant data and provide results",
@@ -28,7 +28,7 @@ def short_term_memory():
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
def test_save_and_search(short_term_memory):
def test_save_and_search(short_term_memory) -> None:
memory = ShortTermMemoryItem(
data="""test value test value test value test value test value test value
test value test value test value test value test value test value
@@ -55,7 +55,7 @@ def test_save_and_search(short_term_memory):
"context": memory.data,
"metadata": {"agent": "test_agent"},
"score": 0.95,
}
},
]
with patch.object(ShortTermMemory, "search", return_value=expected_result):
find = short_term_memory.search("test value", score_threshold=0.01)[0]

View File

@@ -8,28 +8,27 @@ from crewai.memory.user.user_memory_item import UserMemoryItem
class MockCrew:
def __init__(self, memory_config):
def __init__(self, memory_config) -> None:
self.memory_config = memory_config
@pytest.fixture
def user_memory():
"""Fixture to create a UserMemory instance"""
"""Fixture to create a UserMemory instance."""
crew = MockCrew(
memory_config={
"provider": "mem0",
"config": {"user_id": "john"},
"user_memory" : {}
}
"user_memory" : {},
},
)
user_memory = MagicMock(spec=UserMemory)
with patch.object(Memory,'__new__',return_value=user_memory):
user_memory_instance = UserMemory(crew=crew)
return user_memory_instance
with patch.object(Memory,"__new__",return_value=user_memory):
return UserMemory(crew=crew)
def test_save_and_search(user_memory):
def test_save_and_search(user_memory) -> None:
memory = UserMemoryItem(
data="""test value test value test value test value test value test value
test value test value test value test value test value test value
@@ -42,13 +41,13 @@ def test_save_and_search(user_memory):
user_memory.save(
value=memory.data,
metadata=memory.metadata,
user=memory.user
user=memory.user,
)
mock_save.assert_called_once_with(
value=memory.data,
metadata=memory.metadata,
user=memory.user
user=memory.user,
)
expected_result = [
@@ -56,12 +55,12 @@ def test_save_and_search(user_memory):
"context": memory.data,
"metadata": {"agent": "test_agent"},
"score": 0.95,
}
},
]
expected_result = ["mocked_result"]
# Use patch.object to mock UserMemory's search method
with patch.object(UserMemory, 'search', return_value=expected_result) as mock_search:
with patch.object(UserMemory, "search", return_value=expected_result) as mock_search:
find = UserMemory.search("test value", score_threshold=0.01)[0]
mock_search.assert_called_once_with("test value", score_threshold=0.01)
assert find == expected_result[0]
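
Two smaller fixes recur in this file: string literals are normalised to double quotes, and a value that was only bound to a local before being returned is now returned directly. An illustrative sketch with invented names (make_user, make_fake_client); mapping these to rules like ruff's Q000 and RET504 is a guess on my part.

from unittest.mock import MagicMock

def make_user(name: str) -> dict:
    # before the lint pass this read: user = {'name': name}; return user
    return {"name": name}

def make_fake_client() -> MagicMock:
    # same direct-return shape the UserMemory fixture above was rewritten into
    return MagicMock(name="client")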

View File

@@ -1,5 +1,3 @@
from typing import List
from unittest.mock import patch
import pytest
@@ -23,7 +21,7 @@ class SimpleCrew:
@agent
def simple_agent(self):
return Agent(
role="Simple Agent", goal="Simple Goal", backstory="Simple Backstory"
role="Simple Agent", goal="Simple Goal", backstory="Simple Backstory",
)
@task
@@ -44,8 +42,8 @@ class InternalCrew:
agents_config = "config/agents.yaml"
tasks_config = "config/tasks.yaml"
agents: List[BaseAgent]
tasks: List[Task]
agents: list[BaseAgent]
tasks: list[Task]
@llm
def local_llm(self):
@@ -87,7 +85,7 @@ class InternalCrew:
return Crew(agents=self.agents, tasks=self.tasks, verbose=True)
def test_agent_memoization():
def test_agent_memoization() -> None:
crew = SimpleCrew()
first_call_result = crew.simple_agent()
second_call_result = crew.simple_agent()
@@ -97,7 +95,7 @@ def test_agent_memoization():
), "Agent memoization is not working as expected"
def test_task_memoization():
def test_task_memoization() -> None:
crew = SimpleCrew()
first_call_result = crew.simple_task()
second_call_result = crew.simple_task()
@@ -107,7 +105,7 @@ def test_task_memoization():
), "Task memoization is not working as expected"
def test_crew_memoization():
def test_crew_memoization() -> None:
crew = InternalCrew()
first_call_result = crew.crew()
second_call_result = crew.crew()
@@ -117,7 +115,7 @@ def test_crew_memoization():
), "Crew references should point to the same object"
def test_task_name():
def test_task_name() -> None:
simple_task = SimpleCrew().simple_task()
assert (
simple_task.name == "simple_task"
@@ -129,7 +127,7 @@ def test_task_name():
), "Custom task name is not being set as expected"
def test_agent_function_calling_llm():
def test_agent_function_calling_llm() -> None:
crew = InternalCrew()
llm = crew.local_llm()
obj_llm_agent = crew.researcher()
@@ -143,7 +141,7 @@ def test_agent_function_calling_llm():
), "agent's function_calling_llm is incorrect"
def test_task_guardrail():
def test_task_guardrail() -> None:
crew = InternalCrew()
research_task = crew.research_task()
assert research_task.guardrail == "ensure each bullet contains its source"
@@ -153,7 +151,7 @@ def test_task_guardrail():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_before_kickoff_modification():
def test_before_kickoff_modification() -> None:
crew = InternalCrew()
inputs = {"topic": "LLMs"}
result = crew.crew().kickoff(inputs=inputs)
@@ -161,7 +159,7 @@ def test_before_kickoff_modification():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_after_kickoff_modification():
def test_after_kickoff_modification() -> None:
crew = InternalCrew()
# Assuming the crew execution returns a dict
result = crew.crew().kickoff({"topic": "LLMs"})
@@ -172,18 +170,18 @@ def test_after_kickoff_modification():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_before_kickoff_with_none_input():
def test_before_kickoff_with_none_input() -> None:
crew = InternalCrew()
crew.crew().kickoff(None)
# Test should pass without raising exceptions
@pytest.mark.vcr(filter_headers=["authorization"])
def test_multiple_before_after_kickoff():
def test_multiple_before_after_kickoff() -> None:
@CrewBase
class MultipleHooksCrew:
agents: List[BaseAgent]
tasks: List[Task]
agents: list[BaseAgent]
tasks: list[Task]
agents_config = "config/agents.yaml"
tasks_config = "config/tasks.yaml"

View File

@@ -1,14 +1,12 @@
"""Tests for deterministic fingerprints in CrewAI components."""
from datetime import datetime
import pytest
from crewai import Agent, Crew, Task
from crewai.security import Fingerprint, SecurityConfig
def test_basic_deterministic_fingerprint():
def test_basic_deterministic_fingerprint() -> None:
"""Test that deterministic fingerprints can be created with a seed."""
# Create two fingerprints with the same seed
seed = "test-deterministic-fingerprint"
@@ -22,7 +20,7 @@ def test_basic_deterministic_fingerprint():
assert fingerprint1.created_at != fingerprint2.created_at
def test_deterministic_fingerprint_with_metadata():
def test_deterministic_fingerprint_with_metadata() -> None:
"""Test that deterministic fingerprints can include metadata."""
seed = "test-with-metadata"
metadata = {"version": "1.0", "environment": "testing"}
@@ -42,7 +40,7 @@ def test_deterministic_fingerprint_with_metadata():
assert fingerprint.metadata != fingerprint2.metadata
def test_agent_with_deterministic_fingerprint():
def test_agent_with_deterministic_fingerprint() -> None:
"""Test using deterministic fingerprints with agents."""
# Create a security config with a deterministic fingerprint
seed = "agent-fingerprint-test"
@@ -54,7 +52,7 @@ def test_agent_with_deterministic_fingerprint():
role="Researcher",
goal="Research quantum computing",
backstory="Expert in quantum physics",
security_config=security_config
security_config=security_config,
)
# Create another agent with the same security config
@@ -62,7 +60,7 @@ def test_agent_with_deterministic_fingerprint():
role="Completely different role",
goal="Different goal",
backstory="Different backstory",
security_config=security_config
security_config=security_config,
)
# Both agents should have the same fingerprint UUID
@@ -75,7 +73,7 @@ def test_agent_with_deterministic_fingerprint():
assert agent1.fingerprint.uuid_str == original_fingerprint
def test_task_with_deterministic_fingerprint():
def test_task_with_deterministic_fingerprint() -> None:
"""Test using deterministic fingerprints with tasks."""
# Create a security config with a deterministic fingerprint
seed = "task-fingerprint-test"
@@ -86,7 +84,7 @@ def test_task_with_deterministic_fingerprint():
agent = Agent(
role="Assistant",
goal="Help with tasks",
backstory="Helpful AI assistant"
backstory="Helpful AI assistant",
)
# Create a task with the deterministic fingerprint
@@ -94,7 +92,7 @@ def test_task_with_deterministic_fingerprint():
description="Analyze data",
expected_output="Data analysis report",
agent=agent,
security_config=security_config
security_config=security_config,
)
# Create another task with the same security config
@@ -102,7 +100,7 @@ def test_task_with_deterministic_fingerprint():
description="Different task description",
expected_output="Different expected output",
agent=agent,
security_config=security_config
security_config=security_config,
)
# Both tasks should have the same fingerprint UUID
@@ -110,7 +108,7 @@ def test_task_with_deterministic_fingerprint():
assert task1.fingerprint.uuid_str == fingerprint.uuid_str
def test_crew_with_deterministic_fingerprint():
def test_crew_with_deterministic_fingerprint() -> None:
"""Test using deterministic fingerprints with crews."""
# Create a security config with a deterministic fingerprint
seed = "crew-fingerprint-test"
@@ -121,33 +119,33 @@ def test_crew_with_deterministic_fingerprint():
agent1 = Agent(
role="Researcher",
goal="Research information",
backstory="Expert researcher"
backstory="Expert researcher",
)
agent2 = Agent(
role="Writer",
goal="Write reports",
backstory="Expert writer"
backstory="Expert writer",
)
# Create a crew with the deterministic fingerprint
crew1 = Crew(
agents=[agent1, agent2],
tasks=[],
security_config=security_config
security_config=security_config,
)
# Create another crew with the same security config but different agents
agent3 = Agent(
role="Analyst",
goal="Analyze data",
backstory="Expert analyst"
backstory="Expert analyst",
)
crew2 = Crew(
agents=[agent3],
tasks=[],
security_config=security_config
security_config=security_config,
)
# Both crews should have the same fingerprint UUID
@@ -155,7 +153,7 @@ def test_crew_with_deterministic_fingerprint():
assert crew1.fingerprint.uuid_str == fingerprint.uuid_str
def test_recreating_components_with_same_seed():
def test_recreating_components_with_same_seed() -> None:
"""Test recreating components with the same seed across sessions."""
# This simulates using the same seed in different runs/sessions
@@ -168,7 +166,7 @@ def test_recreating_components_with_same_seed():
role="Researcher",
goal="Research topic",
backstory="Expert researcher",
security_config=security_config1
security_config=security_config1,
)
uuid_from_first_session = agent1.fingerprint.uuid_str
@@ -181,14 +179,14 @@ def test_recreating_components_with_same_seed():
role="Researcher",
goal="Research topic",
backstory="Expert researcher",
security_config=security_config2
security_config=security_config2,
)
# Should have same UUID across sessions
assert agent2.fingerprint.uuid_str == uuid_from_first_session
def test_security_config_with_seed_string():
def test_security_config_with_seed_string() -> None:
"""Test creating SecurityConfig with a seed string directly."""
# SecurityConfig can accept a string as fingerprint parameter
# which will be used as a seed to generate a deterministic fingerprint
@@ -209,14 +207,14 @@ def test_security_config_with_seed_string():
role="Tester",
goal="Test fingerprints",
backstory="Expert tester",
security_config=security_config
security_config=security_config,
)
# Agent should have the same fingerprint UUID
assert agent.fingerprint.uuid_str == expected_fingerprint.uuid_str
def test_complex_component_hierarchy_with_deterministic_fingerprints():
def test_complex_component_hierarchy_with_deterministic_fingerprints() -> None:
"""Test a complex hierarchy of components all using deterministic fingerprints."""
# Create a deterministic fingerprint for each component
agent_seed = "deterministic-agent-seed"
@@ -236,7 +234,7 @@ def test_complex_component_hierarchy_with_deterministic_fingerprints():
role="Complex Test Agent",
goal="Test complex fingerprint scenarios",
backstory="Expert in testing",
security_config=agent_config
security_config=agent_config,
)
# Create a task
@@ -244,14 +242,14 @@ def test_complex_component_hierarchy_with_deterministic_fingerprints():
description="Test complex fingerprinting",
expected_output="Verification of fingerprint stability",
agent=agent,
security_config=task_config
security_config=task_config,
)
# Create a crew
crew = Crew(
agents=[agent],
tasks=[task],
security_config=crew_config
security_config=crew_config,
)
# Each component should have its own deterministic fingerprint
@@ -271,4 +269,4 @@ def test_complex_component_hierarchy_with_deterministic_fingerprints():
assert agent_fingerprint.uuid_str == agent_fingerprint2.uuid_str
assert task_fingerprint.uuid_str == task_fingerprint2.uuid_str
assert crew_fingerprint.uuid_str == crew_fingerprint2.uuid_str
assert crew_fingerprint.uuid_str == crew_fingerprint2.uuid_str
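
Nearly every change in this file is the same mechanical fix: when a call spans multiple lines, the final argument now carries a trailing comma, so adding or reordering arguments later only touches one line in future diffs. A hypothetical example (describe and its arguments are invented, and assuming a trailing-comma rule along the lines of ruff's COM812):

def describe(role: str, goal: str, backstory: str) -> str:
    return f"{role}: {goal} ({backstory})"

summary = describe(
    role="Researcher",
    goal="Research quantum computing",
    backstory="Expert in quantum physics",  # trailing comma added by the linter
)
print(summary)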

View File

@@ -1,16 +1,15 @@
"""Test for the examples in the fingerprinting documentation."""
import pytest
from crewai import Agent, Crew, Task
from crewai.security import Fingerprint, SecurityConfig
def test_basic_usage_examples():
def test_basic_usage_examples() -> None:
"""Test the basic usage examples from the documentation."""
# Creating components with automatic fingerprinting
agent = Agent(
role="Data Scientist", goal="Analyze data", backstory="Expert in data analysis"
role="Data Scientist", goal="Analyze data", backstory="Expert in data analysis",
)
# Verify the agent has a fingerprint
@@ -35,11 +34,11 @@ def test_basic_usage_examples():
assert task.fingerprint.uuid_str is not None
def test_accessing_fingerprints_example():
def test_accessing_fingerprints_example() -> None:
"""Test the accessing fingerprints example from the documentation."""
# Create components
agent = Agent(
role="Data Scientist", goal="Analyze data", backstory="Expert in data analysis"
role="Data Scientist", goal="Analyze data", backstory="Expert in data analysis",
)
crew = Crew(agents=[agent], tasks=[])
@@ -75,11 +74,11 @@ def test_accessing_fingerprints_example():
task_fingerprint.uuid_str,
]
assert len(fingerprints) == len(
set(fingerprints)
set(fingerprints),
), "All fingerprints should be unique"
def test_fingerprint_metadata_example():
def test_fingerprint_metadata_example() -> None:
"""Test using the Fingerprint's metadata for additional information."""
# Create a SecurityConfig with custom metadata
security_config = SecurityConfig()
@@ -97,7 +96,7 @@ def test_fingerprint_metadata_example():
assert agent.fingerprint.metadata == {"version": "1.0", "author": "John Doe"}
def test_fingerprint_with_security_config():
def test_fingerprint_with_security_config() -> None:
"""Test example of using a SecurityConfig with components."""
# Create a SecurityConfig
security_config = SecurityConfig()
@@ -125,15 +124,15 @@ def test_fingerprint_with_security_config():
assert task.security_config is security_config
def test_complete_workflow_example():
def test_complete_workflow_example() -> None:
"""Test the complete workflow example from the documentation."""
# Create agents with auto-generated fingerprints
researcher = Agent(
role="Researcher", goal="Find information", backstory="Expert researcher"
role="Researcher", goal="Find information", backstory="Expert researcher",
)
writer = Agent(
role="Writer", goal="Create content", backstory="Professional writer"
role="Writer", goal="Create content", backstory="Professional writer",
)
# Create tasks with auto-generated fingerprints
@@ -151,7 +150,7 @@ def test_complete_workflow_example():
# Create a crew with auto-generated fingerprint
content_crew = Crew(
agents=[researcher, writer], tasks=[research_task, writing_task]
agents=[researcher, writer], tasks=[research_task, writing_task],
)
# Verify everything has auto-generated fingerprints
@@ -170,11 +169,11 @@ def test_complete_workflow_example():
content_crew.fingerprint.uuid_str,
]
assert len(fingerprints) == len(
set(fingerprints)
set(fingerprints),
), "All fingerprints should be unique"
def test_security_preservation_during_copy():
def test_security_preservation_during_copy() -> None:
"""Test that security configurations are preserved when copying Crew and Agent objects."""
# Create a SecurityConfig with custom metadata
security_config = SecurityConfig()
@@ -197,7 +196,7 @@ def test_security_preservation_during_copy():
# Create a crew with the agent and task
original_crew = Crew(
agents=[original_agent], tasks=[task], security_config=security_config
agents=[original_agent], tasks=[task], security_config=security_config,
)
# Copy the agent and crew

View File

@@ -5,12 +5,11 @@ import uuid
from datetime import datetime, timedelta
import pytest
from pydantic import ValidationError
from crewai.security import Fingerprint
def test_fingerprint_creation_with_defaults():
def test_fingerprint_creation_with_defaults() -> None:
"""Test creating a Fingerprint with default values."""
fingerprint = Fingerprint()
@@ -27,7 +26,7 @@ def test_fingerprint_creation_with_defaults():
assert fingerprint.metadata == {}
def test_fingerprint_creation_with_metadata():
def test_fingerprint_creation_with_metadata() -> None:
"""Test creating a Fingerprint with custom metadata only."""
metadata = {"version": "1.0", "author": "Test Author"}
@@ -40,7 +39,7 @@ def test_fingerprint_creation_with_metadata():
assert fingerprint.metadata == metadata
def test_fingerprint_uuid_cannot_be_set():
def test_fingerprint_uuid_cannot_be_set() -> None:
"""Test that uuid_str cannot be manually set."""
original_uuid = "b723c6ff-95de-5e87-860b-467b72282bd8"
@@ -52,7 +51,7 @@ def test_fingerprint_uuid_cannot_be_set():
assert uuid.UUID(fingerprint.uuid_str) # Should be a valid UUID
def test_fingerprint_created_at_cannot_be_set():
def test_fingerprint_created_at_cannot_be_set() -> None:
"""Test that created_at cannot be manually set."""
original_time = datetime.now() - timedelta(days=1)
@@ -64,7 +63,7 @@ def test_fingerprint_created_at_cannot_be_set():
assert fingerprint.created_at > original_time # Should be more recent
def test_fingerprint_uuid_property():
def test_fingerprint_uuid_property() -> None:
"""Test the uuid property returns a UUID object."""
fingerprint = Fingerprint()
@@ -72,7 +71,7 @@ def test_fingerprint_uuid_property():
assert str(fingerprint.uuid) == fingerprint.uuid_str
def test_fingerprint_deterministic_generation():
def test_fingerprint_deterministic_generation() -> None:
"""Test that the same seed string always generates the same fingerprint using generate method."""
seed = "test-seed"
@@ -88,7 +87,7 @@ def test_fingerprint_deterministic_generation():
assert uuid_str1 == uuid_str2
def test_fingerprint_generate_classmethod():
def test_fingerprint_generate_classmethod() -> None:
"""Test the generate class method."""
# Without seed
fingerprint1 = Fingerprint.generate()
@@ -107,7 +106,7 @@ def test_fingerprint_generate_classmethod():
assert fingerprint2.uuid_str == fingerprint3.uuid_str
def test_fingerprint_string_representation():
def test_fingerprint_string_representation() -> None:
"""Test the string representation of Fingerprint."""
fingerprint = Fingerprint()
uuid_str = fingerprint.uuid_str
@@ -116,7 +115,7 @@ def test_fingerprint_string_representation():
assert uuid_str in string_repr
def test_fingerprint_equality():
def test_fingerprint_equality() -> None:
"""Test fingerprint equality comparison."""
# Using generate with the same seed to get consistent UUIDs
seed = "test-equality"
@@ -129,7 +128,7 @@ def test_fingerprint_equality():
assert fingerprint1 != fingerprint3
def test_fingerprint_hash():
def test_fingerprint_hash() -> None:
"""Test that fingerprints can be used as dictionary keys."""
# Using generate with the same seed to get consistent UUIDs
seed = "test-hash"
@@ -145,7 +144,7 @@ def test_fingerprint_hash():
assert fingerprint_dict[fingerprint2] == "value"
def test_fingerprint_to_dict():
def test_fingerprint_to_dict() -> None:
"""Test converting fingerprint to dictionary."""
metadata = {"version": "1.0"}
fingerprint = Fingerprint(metadata=metadata)
@@ -160,7 +159,7 @@ def test_fingerprint_to_dict():
assert fingerprint_dict["metadata"] == metadata
def test_fingerprint_from_dict():
def test_fingerprint_from_dict() -> None:
"""Test creating fingerprint from dictionary."""
uuid_str = "b723c6ff-95de-5e87-860b-467b72282bd8"
created_at = datetime.now()
@@ -170,7 +169,7 @@ def test_fingerprint_from_dict():
fingerprint_dict = {
"uuid_str": uuid_str,
"created_at": created_at_iso,
"metadata": metadata
"metadata": metadata,
}
fingerprint = Fingerprint.from_dict(fingerprint_dict)
@@ -180,7 +179,7 @@ def test_fingerprint_from_dict():
assert fingerprint.metadata == metadata
def test_fingerprint_json_serialization():
def test_fingerprint_json_serialization() -> None:
"""Test that Fingerprint can be JSON serialized and deserialized."""
# Create a fingerprint, get its values
metadata = {"version": "1.0"}
@@ -202,7 +201,7 @@ def test_fingerprint_json_serialization():
assert new_fingerprint.metadata == metadata
def test_invalid_uuid_str():
def test_invalid_uuid_str() -> None:
"""Test handling of invalid UUID strings."""
uuid_str = "not-a-valid-uuid"
created_at = datetime.now().isoformat()
@@ -210,7 +209,7 @@ def test_invalid_uuid_str():
fingerprint_dict = {
"uuid_str": uuid_str,
"created_at": created_at,
"metadata": {}
"metadata": {},
}
# The Fingerprint.from_dict method accepts even invalid UUIDs
@@ -223,10 +222,10 @@ def test_invalid_uuid_str():
# But this will raise an exception when we try to access the uuid property
with pytest.raises(ValueError):
uuid_obj = fingerprint.uuid
pass
def test_fingerprint_metadata_mutation():
def test_fingerprint_metadata_mutation() -> None:
"""Test that metadata can be modified after fingerprint creation."""
# Create a fingerprint with initial metadata
initial_metadata = {"version": "1.0", "status": "draft"}
@@ -243,7 +242,7 @@ def test_fingerprint_metadata_mutation():
expected_metadata = {
"version": "1.0",
"status": "published",
"author": "Test Author"
"author": "Test Author",
}
assert fingerprint.metadata == expected_metadata
@@ -260,4 +259,4 @@ def test_fingerprint_metadata_mutation():
# Ensure immutable fields remain unchanged
assert fingerprint.uuid_str == uuid_str
assert fingerprint.created_at == created_at
assert fingerprint.created_at == created_at

View File

@@ -1,12 +1,11 @@
"""Test integration of fingerprinting with Agent, Crew, and Task classes."""
import pytest
from crewai import Agent, Crew, Task
from crewai.security import Fingerprint, SecurityConfig
def test_agent_with_security_config():
def test_agent_with_security_config() -> None:
"""Test creating an Agent with a SecurityConfig."""
# Create agent with SecurityConfig
security_config = SecurityConfig()
@@ -15,7 +14,7 @@ def test_agent_with_security_config():
role="Tester",
goal="Test fingerprinting",
backstory="Testing fingerprinting",
security_config=security_config
security_config=security_config,
)
assert agent.security_config is not None
@@ -24,13 +23,13 @@ def test_agent_with_security_config():
assert agent.fingerprint is not None
def test_agent_fingerprint_property():
def test_agent_fingerprint_property() -> None:
"""Test the fingerprint property on Agent."""
# Create agent without security_config
agent = Agent(
role="Tester",
goal="Test fingerprinting",
backstory="Testing fingerprinting"
backstory="Testing fingerprinting",
)
# Fingerprint should be automatically generated
@@ -39,7 +38,7 @@ def test_agent_fingerprint_property():
assert agent.security_config is not None
def test_crew_with_security_config():
def test_crew_with_security_config() -> None:
"""Test creating a Crew with a SecurityConfig."""
# Create crew with SecurityConfig
security_config = SecurityConfig()
@@ -47,18 +46,18 @@ def test_crew_with_security_config():
agent1 = Agent(
role="Tester1",
goal="Test fingerprinting",
backstory="Testing fingerprinting"
backstory="Testing fingerprinting",
)
agent2 = Agent(
role="Tester2",
goal="Test fingerprinting",
backstory="Testing fingerprinting"
backstory="Testing fingerprinting",
)
crew = Crew(
agents=[agent1, agent2],
security_config=security_config
security_config=security_config,
)
assert crew.security_config is not None
@@ -67,19 +66,19 @@ def test_crew_with_security_config():
assert crew.fingerprint is not None
def test_crew_fingerprint_property():
def test_crew_fingerprint_property() -> None:
"""Test the fingerprint property on Crew."""
# Create crew without security_config
agent1 = Agent(
role="Tester1",
goal="Test fingerprinting",
backstory="Testing fingerprinting"
backstory="Testing fingerprinting",
)
agent2 = Agent(
role="Tester2",
goal="Test fingerprinting",
backstory="Testing fingerprinting"
backstory="Testing fingerprinting",
)
crew = Crew(agents=[agent1, agent2])
@@ -90,7 +89,7 @@ def test_crew_fingerprint_property():
assert crew.security_config is not None
def test_task_with_security_config():
def test_task_with_security_config() -> None:
"""Test creating a Task with a SecurityConfig."""
# Create task with SecurityConfig
security_config = SecurityConfig()
@@ -98,14 +97,14 @@ def test_task_with_security_config():
agent = Agent(
role="Tester",
goal="Test fingerprinting",
backstory="Testing fingerprinting"
backstory="Testing fingerprinting",
)
task = Task(
description="Test task",
expected_output="Testing output",
agent=agent,
security_config=security_config
security_config=security_config,
)
assert task.security_config is not None
@@ -114,19 +113,19 @@ def test_task_with_security_config():
assert task.fingerprint is not None
def test_task_fingerprint_property():
def test_task_fingerprint_property() -> None:
"""Test the fingerprint property on Task."""
# Create task without security_config
agent = Agent(
role="Tester",
goal="Test fingerprinting",
backstory="Testing fingerprinting"
backstory="Testing fingerprinting",
)
task = Task(
description="Test task",
expected_output="Testing output",
agent=agent
agent=agent,
)
# Fingerprint should be automatically generated
@@ -135,36 +134,36 @@ def test_task_fingerprint_property():
assert task.security_config is not None
def test_end_to_end_fingerprinting():
def test_end_to_end_fingerprinting() -> None:
"""Test end-to-end fingerprinting across Agent, Crew, and Task."""
# Create components with auto-generated fingerprints
agent1 = Agent(
role="Researcher",
goal="Research information",
backstory="Expert researcher"
backstory="Expert researcher",
)
agent2 = Agent(
role="Writer",
goal="Write content",
backstory="Expert writer"
backstory="Expert writer",
)
task1 = Task(
description="Research topic",
expected_output="Research findings",
agent=agent1
agent=agent1,
)
task2 = Task(
description="Write article",
expected_output="Written article",
agent=agent2
agent=agent2,
)
crew = Crew(
agents=[agent1, agent2],
tasks=[task1, task2]
tasks=[task1, task2],
)
# Verify all fingerprints were automatically generated
@@ -180,18 +179,18 @@ def test_end_to_end_fingerprinting():
agent2.fingerprint.uuid_str,
task1.fingerprint.uuid_str,
task2.fingerprint.uuid_str,
crew.fingerprint.uuid_str
crew.fingerprint.uuid_str,
]
assert len(fingerprints) == len(set(fingerprints)), "All fingerprints should be unique"
def test_fingerprint_persistence():
def test_fingerprint_persistence() -> None:
"""Test that fingerprints persist and don't change."""
# Create an agent and check its fingerprint
agent = Agent(
role="Tester",
goal="Test fingerprinting",
backstory="Testing fingerprinting"
backstory="Testing fingerprinting",
)
# Get initial fingerprint
@@ -204,7 +203,7 @@ def test_fingerprint_persistence():
task = Task(
description="Test task",
expected_output="Testing output",
agent=agent
agent=agent,
)
# Check that task has its own unique fingerprint
@@ -212,7 +211,7 @@ def test_fingerprint_persistence():
assert task.fingerprint.uuid_str != agent.fingerprint.uuid_str
def test_shared_security_config_fingerprints():
def test_shared_security_config_fingerprints() -> None:
"""Test that components with the same SecurityConfig share the same fingerprint."""
# Create a shared SecurityConfig
shared_security_config = SecurityConfig()
@@ -223,27 +222,27 @@ def test_shared_security_config_fingerprints():
role="Researcher",
goal="Research information",
backstory="Expert researcher",
security_config=shared_security_config
security_config=shared_security_config,
)
agent2 = Agent(
role="Writer",
goal="Write content",
backstory="Expert writer",
security_config=shared_security_config
security_config=shared_security_config,
)
task = Task(
description="Write article",
expected_output="Written article",
agent=agent1,
security_config=shared_security_config
security_config=shared_security_config,
)
crew = Crew(
agents=[agent1, agent2],
tasks=[task],
security_config=shared_security_config
security_config=shared_security_config,
)
# Verify all components have the same fingerprint UUID
@@ -256,4 +255,4 @@ def test_shared_security_config_fingerprints():
assert agent1.fingerprint is shared_security_config.fingerprint
assert agent2.fingerprint is shared_security_config.fingerprint
assert task.fingerprint is shared_security_config.fingerprint
assert crew.fingerprint is shared_security_config.fingerprint
assert crew.fingerprint is shared_security_config.fingerprint

View File

@@ -6,7 +6,7 @@ from datetime import datetime
from crewai.security import Fingerprint, SecurityConfig
def test_security_config_creation_with_defaults():
def test_security_config_creation_with_defaults() -> None:
"""Test creating a SecurityConfig with default values."""
config = SecurityConfig()
@@ -16,7 +16,7 @@ def test_security_config_creation_with_defaults():
assert config.fingerprint.uuid_str is not None # UUID is auto-generated
def test_security_config_fingerprint_generation():
def test_security_config_fingerprint_generation() -> None:
"""Test that SecurityConfig automatically generates fingerprints."""
config = SecurityConfig()
@@ -27,7 +27,7 @@ def test_security_config_fingerprint_generation():
assert len(config.fingerprint.uuid_str) > 0
def test_security_config_init_params():
def test_security_config_init_params() -> None:
"""Test that SecurityConfig can be initialized and modified."""
# Create a config
config = SecurityConfig()
@@ -43,7 +43,7 @@ def test_security_config_init_params():
assert config.fingerprint.metadata == {"version": "1.0"}
def test_security_config_to_dict():
def test_security_config_to_dict() -> None:
"""Test converting SecurityConfig to dictionary."""
# Create a config with a fingerprint that has metadata
config = SecurityConfig()
@@ -57,19 +57,16 @@ def test_security_config_to_dict():
assert config_dict["fingerprint"]["metadata"] == {"version": "1.0"}
def test_security_config_from_dict():
def test_security_config_from_dict() -> None:
"""Test creating SecurityConfig from dictionary."""
# Create a fingerprint dict
fingerprint_dict = {
"uuid_str": "b723c6ff-95de-5e87-860b-467b72282bd8",
"created_at": datetime.now().isoformat(),
"metadata": {"version": "1.0"}
"metadata": {"version": "1.0"},
}
# Create a config dict with just the fingerprint
config_dict = {
"fingerprint": fingerprint_dict
}
# Create config manually since from_dict has a specific implementation
config = SecurityConfig()
@@ -85,7 +82,7 @@ def test_security_config_from_dict():
assert config.fingerprint.metadata == fingerprint_dict["metadata"]
def test_security_config_json_serialization():
def test_security_config_json_serialization() -> None:
"""Test that SecurityConfig can be JSON serialized and deserialized."""
# Create a config with fingerprint metadata
config = SecurityConfig()
@@ -115,4 +112,4 @@ def test_security_config_json_serialization():
new_config.fingerprint = new_fingerprint
# Check the new config has the same fingerprint metadata
assert new_config.fingerprint.metadata == {"version": "1.0"}
assert new_config.fingerprint.metadata == {"version": "1.0"}

View File

@@ -1,34 +1,28 @@
import os
from unittest.mock import MagicMock, patch
import pytest
from mem0.client.main import MemoryClient
from mem0.memory.main import Memory
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.memory.storage.mem0_storage import Mem0Storage
from crewai.task import Task
# Define the class (if not already defined)
class MockCrew:
def __init__(self, memory_config):
def __init__(self, memory_config) -> None:
self.memory_config = memory_config
self.agents = [MagicMock(role="Test Agent")]
@pytest.fixture
def mock_mem0_memory():
"""Fixture to create a mock Memory instance"""
mock_memory = MagicMock(spec=Memory)
return mock_memory
"""Fixture to create a mock Memory instance."""
return MagicMock(spec=Memory)
@pytest.fixture
def mem0_storage_with_mocked_config(mock_mem0_memory):
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
"""Fixture to create a Mem0Storage instance with mocked dependencies."""
# Patch the Memory class to return our mock
with patch("mem0.memory.main.Memory.from_config", return_value=mock_mem0_memory) as mock_from_config:
config = {
@@ -63,15 +57,15 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
memory_config={
"provider": "mem0",
"config": {"user_id": "test_user", "local_mem0_config": config},
}
},
)
mem0_storage = Mem0Storage(type="short_term", crew=crew)
return mem0_storage, mock_from_config, config
def test_mem0_storage_initialization(mem0_storage_with_mocked_config, mock_mem0_memory):
"""Test that Mem0Storage initializes correctly with the mocked config"""
def test_mem0_storage_initialization(mem0_storage_with_mocked_config, mock_mem0_memory) -> None:
"""Test that Mem0Storage initializes correctly with the mocked config."""
mem0_storage, mock_from_config, config = mem0_storage_with_mocked_config
assert mem0_storage.memory_type == "short_term"
assert mem0_storage.memory is mock_mem0_memory
@@ -80,15 +74,13 @@ def test_mem0_storage_initialization(mem0_storage_with_mocked_config, mock_mem0_
@pytest.fixture
def mock_mem0_memory_client():
"""Fixture to create a mock MemoryClient instance"""
mock_memory = MagicMock(spec=MemoryClient)
return mock_memory
"""Fixture to create a mock MemoryClient instance."""
return MagicMock(spec=MemoryClient)
@pytest.fixture
def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_client):
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
"""Fixture to create a Mem0Storage instance with mocked dependencies."""
# We need to patch the MemoryClient before it's instantiated
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
crew = MockCrew(
@@ -100,17 +92,15 @@ def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_clie
"org_id": "my_org_id",
"project_id": "my_project_id",
},
}
},
)
mem0_storage = Mem0Storage(type="short_term", crew=crew)
return mem0_storage
return Mem0Storage(type="short_term", crew=crew)
@pytest.fixture
def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client, mock_mem0_memory):
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
"""Fixture to create a Mem0Storage instance with mocked dependencies."""
# We need to patch both MemoryClient and Memory to prevent actual initialization
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client), \
patch.object(Memory, "__new__", return_value=mock_mem0_memory):
@@ -124,19 +114,18 @@ def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_clie
"org_id": "my_org_id",
"project_id": "my_project_id",
},
}
},
)
new_config = {"provider": "mem0", "config": {"api_key": "new-api-key"}}
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=new_config)
return mem0_storage
return Mem0Storage(type="short_term", crew=crew, config=new_config)
def test_mem0_storage_with_memory_client_initialization(
mem0_storage_with_memory_client_using_config_from_crew, mock_mem0_memory_client
):
"""Test Mem0Storage initialization with MemoryClient"""
mem0_storage_with_memory_client_using_config_from_crew, mock_mem0_memory_client,
) -> None:
"""Test Mem0Storage initialization with MemoryClient."""
assert (
mem0_storage_with_memory_client_using_config_from_crew.memory_type
== "short_term"
@@ -149,7 +138,7 @@ def test_mem0_storage_with_memory_client_initialization(
def test_mem0_storage_with_explict_config(
mem0_storage_with_memory_client_using_explictly_config,
):
) -> None:
expected_config = {"provider": "mem0", "config": {"api_key": "new-api-key"}}
assert (
mem0_storage_with_memory_client_using_explictly_config.config == expected_config
@@ -160,17 +149,17 @@ def test_mem0_storage_with_explict_config(
)
def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test save method for different memory types"""
def test_save_method_with_memory_oss(mem0_storage_with_mocked_config) -> None:
"""Test save method for different memory types."""
mem0_storage, _, _ = mem0_storage_with_mocked_config
mem0_storage.memory.add = MagicMock()
# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {"key": "value"}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
test_value,
agent_id="Test_Agent",
@@ -179,28 +168,28 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
)
def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
"""Test save method for different memory types"""
def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew) -> None:
"""Test save method for different memory types."""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mem0_storage.memory.add = MagicMock()
# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {"key": "value"}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
test_value,
agent_id="Test_Agent",
infer=False,
metadata={"type": "short_term", "key": "value"},
output_format="v1.1"
output_format="v1.1",
)
def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test search method for different memory types"""
def test_search_method_with_memory_oss(mem0_storage_with_mocked_config) -> None:
"""Test search method for different memory types."""
mem0_storage, _, _ = mem0_storage_with_mocked_config
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
@@ -208,18 +197,18 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
query="test query",
limit=5,
agent_id="Test_Agent",
user_id="test_user"
user_id="test_user",
)
assert len(results) == 1
assert results[0]["content"] == "Result 1"
def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
"""Test search method for different memory types"""
def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew) -> None:
"""Test search method for different memory types."""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
@@ -227,12 +216,12 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
agent_id="Test_Agent",
query="test query",
limit=5,
agent_id="Test_Agent",
metadata={"type": "short_term"},
user_id="test_user",
output_format='v1.1'
output_format="v1.1",
)
assert len(results) == 1
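
Two recurring fixes in this file: fixture docstrings gain a terminating period, and fixtures return the mock directly instead of assigning it to a local first. A self-contained sketch under those assumptions; FakeMemoryClient, mock_client, and test_mock_client_add are invented names, and pinning the docstring change to pydocstyle-style D400/D415 checks is a presumption the commit does not spell out.

from unittest.mock import MagicMock

import pytest

class FakeMemoryClient:
    def add(self, value: str) -> None:
        pass

@pytest.fixture
def mock_client() -> MagicMock:
    """Fixture to create a mock client instance."""
    # before: mock_memory = MagicMock(spec=FakeMemoryClient); return mock_memory
    return MagicMock(spec=FakeMemoryClient)

def test_mock_client_add(mock_client) -> None:
    mock_client.add("hello")
    mock_client.add.assert_called_once_with("hello")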

View File

@@ -5,7 +5,6 @@ import json
import os
import time
from functools import partial
from typing import Tuple, Union
from unittest.mock import MagicMock, patch
import pytest
@@ -19,12 +18,12 @@ from crewai.utilities.converter import Converter
from crewai.utilities.string_utils import interpolate_only
def test_task_tool_reflect_agent_tools():
def test_task_tool_reflect_agent_tools() -> None:
from crewai.tools import tool
@tool
def fake_tool() -> None:
"Fake tool"
"""Fake tool."""
researcher = Agent(
role="Researcher",
@@ -43,16 +42,16 @@ def test_task_tool_reflect_agent_tools():
assert task.tools == [fake_tool]
def test_task_tool_takes_precedence_over_agent_tools():
def test_task_tool_takes_precedence_over_agent_tools() -> None:
from crewai.tools import tool
@tool
def fake_tool() -> None:
"Fake tool"
"""Fake tool."""
@tool
def fake_task_tool() -> None:
"Fake tool"
"""Fake tool."""
researcher = Agent(
role="Researcher",
@@ -72,7 +71,7 @@ def test_task_tool_takes_precedence_over_agent_tools():
assert task.tools == [fake_task_tool]
def test_task_prompt_includes_expected_output():
def test_task_prompt_includes_expected_output() -> None:
researcher = Agent(
role="Researcher",
goal="Make the best research and analysis on content about AI and AI agents",
@@ -92,7 +91,7 @@ def test_task_prompt_includes_expected_output():
execute.assert_called_once_with(task=task, context=None, tools=[])
def test_task_callback():
def test_task_callback() -> None:
researcher = Agent(
role="Researcher",
goal="Make the best research and analysis on content about AI and AI agents",
@@ -120,7 +119,7 @@ def test_task_callback():
assert task.output.name == task.name
def test_task_callback_returns_task_output():
def test_task_callback_returns_task_output() -> None:
from crewai.tasks.output_format import OutputFormat
researcher = Agent(
@@ -166,7 +165,7 @@ def test_task_callback_returns_task_output():
assert output_dict == expected_output
def test_execute_with_agent():
def test_execute_with_agent() -> None:
researcher = Agent(
role="Researcher",
goal="Make the best research and analysis on content about AI and AI agents",
@@ -184,7 +183,7 @@ def test_execute_with_agent():
execute.assert_called_once_with(task=task, context=None, tools=[])
def test_async_execution():
def test_async_execution() -> None:
researcher = Agent(
role="Researcher",
goal="Make the best research and analysis on content about AI and AI agents",
@@ -206,7 +205,7 @@ def test_async_execution():
execute.assert_called_once_with(task=task, context=None, tools=[])
def test_multiple_output_type_error():
def test_multiple_output_type_error() -> None:
class Output(BaseModel):
field: str
@@ -219,7 +218,7 @@ def test_multiple_output_type_error():
)
def test_guardrail_type_error():
def test_guardrail_type_error() -> None:
desc = "Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting."
expected_output = "Bullet point list of 5 interesting ideas."
# Lambda function
@@ -248,7 +247,7 @@ def test_guardrail_type_error():
return (True, x)
@staticmethod
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, Union[str, TaskOutput]]:
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, str | TaskOutput]:
return (True, x)
obj = Object()
@@ -271,7 +270,7 @@ def test_guardrail_type_error():
guardrail=Object.guardrail_static_fn,
)
def error_fn(x: TaskOutput, y: bool) -> Tuple[bool, TaskOutput]:
def error_fn(x: TaskOutput, y: bool) -> tuple[bool, TaskOutput]:
return (y, x)
Task(
@@ -289,7 +288,7 @@ def test_guardrail_type_error():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_output_pydantic_sequential():
def test_output_pydantic_sequential() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -314,7 +313,7 @@ def test_output_pydantic_sequential():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_output_pydantic_hierarchical():
def test_output_pydantic_hierarchical() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -344,7 +343,7 @@ def test_output_pydantic_hierarchical():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_output_json_sequential():
def test_output_json_sequential() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -365,12 +364,12 @@ def test_output_json_sequential():
crew = Crew(agents=[scorer], tasks=[task], process=Process.sequential)
result = crew.kickoff()
assert '{"score": 4}' == result.json
assert result.json == '{"score": 4}'
assert result.to_dict() == {"score": 4}
@pytest.mark.vcr(filter_headers=["authorization"])
def test_output_json_hierarchical():
def test_output_json_hierarchical() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -400,7 +399,7 @@ def test_output_json_hierarchical():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_json_property_without_output_json():
def test_json_property_without_output_json() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -428,7 +427,7 @@ def test_json_property_without_output_json():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_output_json_dict_sequential():
def test_output_json_dict_sequential() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -448,12 +447,12 @@ def test_output_json_dict_sequential():
crew = Crew(agents=[scorer], tasks=[task], process=Process.sequential)
result = crew.kickoff()
assert {"score": 4} == result.json_dict
assert result.json_dict == {"score": 4}
assert result.to_dict() == {"score": 4}
@pytest.mark.vcr(filter_headers=["authorization"])
def test_output_json_dict_hierarchical():
def test_output_json_dict_hierarchical() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -478,12 +477,12 @@ def test_output_json_dict_hierarchical():
manager_llm="gpt-4o",
)
result = crew.kickoff()
assert {"score": 4} == result.json_dict
assert result.json_dict == {"score": 4}
assert result.to_dict() == {"score": 4}
@pytest.mark.vcr(filter_headers=["authorization"])
def test_output_pydantic_to_another_task():
def test_output_pydantic_to_another_task() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -515,13 +514,13 @@ def test_output_pydantic_to_another_task():
result = crew.kickoff()
pydantic_result = result.pydantic
assert isinstance(
pydantic_result, ScoreOutput
pydantic_result, ScoreOutput,
), "Expected pydantic result to be of type ScoreOutput"
assert pydantic_result.score == 5
@pytest.mark.vcr(filter_headers=["authorization"])
def test_output_json_to_another_task():
def test_output_json_to_another_task() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -548,11 +547,11 @@ def test_output_json_to_another_task():
crew = Crew(agents=[scorer], tasks=[task1, task2])
result = crew.kickoff()
assert '{"score": 4}' == result.json
assert result.json == '{"score": 4}'
@pytest.mark.vcr(filter_headers=["authorization"])
def test_save_task_output():
def test_save_task_output() -> None:
scorer = Agent(
role="Scorer",
goal="Score the title",
@@ -576,7 +575,7 @@ def test_save_task_output():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_save_task_json_output():
def test_save_task_json_output() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -600,13 +599,13 @@ def test_save_task_json_output():
output_file_exists = os.path.exists("score.json")
assert output_file_exists
assert {"score": 4} == json.loads(open("score.json").read())
assert json.loads(open("score.json").read()) == {"score": 4}
if output_file_exists:
os.remove("score.json")
@pytest.mark.vcr(filter_headers=["authorization"])
def test_save_task_pydantic_output():
def test_save_task_pydantic_output() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -630,13 +629,13 @@ def test_save_task_pydantic_output():
output_file_exists = os.path.exists("score.json")
assert output_file_exists
assert {"score": 4} == json.loads(open("score.json").read())
assert json.loads(open("score.json").read()) == {"score": 4}
if output_file_exists:
os.remove("score.json")
@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_converter_cls():
def test_custom_converter_cls() -> None:
class ScoreOutput(BaseModel):
score: int
@@ -661,14 +660,14 @@ def test_custom_converter_cls():
crew = Crew(agents=[scorer], tasks=[task])
with patch.object(
ScoreConverter, "to_pydantic", return_value=ScoreOutput(score=5)
ScoreConverter, "to_pydantic", return_value=ScoreOutput(score=5),
) as mock_to_pydantic:
crew.kickoff()
mock_to_pydantic.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_increment_delegations_for_hierarchical_process():
def test_increment_delegations_for_hierarchical_process() -> None:
scorer = Agent(
role="Scorer",
goal="Score the title",
@@ -695,7 +694,7 @@ def test_increment_delegations_for_hierarchical_process():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_increment_delegations_for_sequential_process():
def test_increment_delegations_for_sequential_process() -> None:
manager = Agent(
role="Manager",
goal="Coordinate scoring processes",
@@ -729,13 +728,14 @@ def test_increment_delegations_for_sequential_process():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_increment_tool_errors():
def test_increment_tool_errors() -> None:
from crewai.tools import tool
@tool
def scoring_examples() -> None:
"Useful examples for scoring titles."
raise Exception("Error")
"""Useful examples for scoring titles."""
msg = "Error"
raise Exception(msg)
scorer = Agent(
role="Scorer",
@@ -762,7 +762,7 @@ def test_increment_tool_errors():
assert len(increment_tools_errors.mock_calls) > 0
def test_task_definition_based_on_dict():
def test_task_definition_based_on_dict() -> None:
config = {
"description": "Give me an integer score between 1-5 for the following title: 'The impact of AI in the future of work', check examples to based your evaluation.",
"expected_output": "The score of the title.",
@@ -775,7 +775,7 @@ def test_task_definition_based_on_dict():
assert task.agent is None
def test_conditional_task_definition_based_on_dict():
def test_conditional_task_definition_based_on_dict() -> None:
config = {
"description": "Give me an integer score between 1-5 for the following title: 'The impact of AI in the future of work', check examples to based your evaluation.",
"expected_output": "The score of the title.",
@@ -788,7 +788,7 @@ def test_conditional_task_definition_based_on_dict():
assert task.agent is None
def test_conditional_task_copy_preserves_type():
def test_conditional_task_copy_preserves_type() -> None:
task_config = {
"description": "Give me an integer score between 1-5 for the following title: 'The impact of AI in the future of work', check examples to based your evaluation.",
"expected_output": "The score of the title.",
@@ -807,7 +807,7 @@ def test_conditional_task_copy_preserves_type():
assert isinstance(copied_conditional_task, ConditionalTask)
def test_interpolate_inputs():
def test_interpolate_inputs() -> None:
task = Task(
description="Give me a list of 5 interesting ideas about {topic} to explore for an article, what makes them unique and interesting.",
expected_output="Bullet point list of 5 interesting ideas about {topic}.",
@@ -815,7 +815,7 @@ def test_interpolate_inputs():
)
task.interpolate_inputs_and_add_conversation_history(
inputs={"topic": "AI", "date": "2025"}
inputs={"topic": "AI", "date": "2025"},
)
assert (
task.description
@@ -825,7 +825,7 @@ def test_interpolate_inputs():
assert task.output_file == "/tmp/AI/output_2025.txt"
task.interpolate_inputs_and_add_conversation_history(
inputs={"topic": "ML", "date": "2025"}
inputs={"topic": "ML", "date": "2025"},
)
assert (
task.description
@@ -835,10 +835,10 @@ def test_interpolate_inputs():
assert task.output_file == "/tmp/ML/output_2025.txt"
def test_interpolate_only():
def test_interpolate_only() -> None:
"""Test the interpolate_only method for various scenarios including JSON structure preservation."""
task = Task(
description="Unused in this test", expected_output="Unused in this test"
Task(
description="Unused in this test", expected_output="Unused in this test",
)
# Test JSON structure preservation
@@ -855,7 +855,7 @@ def test_interpolate_only():
# Test normal string interpolation
normal_string = "Hello {name}, welcome to {place}!"
result = interpolate_only(
input_string=normal_string, inputs={"name": "John", "place": "CrewAI"}
input_string=normal_string, inputs={"name": "John", "place": "CrewAI"},
)
assert result == "Hello John, welcome to CrewAI!"
@@ -869,9 +869,9 @@ def test_interpolate_only():
assert result == no_placeholders
def test_interpolate_only_with_dict_inside_expected_output():
def test_interpolate_only_with_dict_inside_expected_output() -> None:
"""Test the interpolate_only method for various scenarios including JSON structure preservation."""
task = Task(
Task(
description="Unused in this test",
expected_output="Unused in this test: {questions}",
)
@@ -883,7 +883,7 @@ def test_interpolate_only_with_dict_inside_expected_output():
"questions": {
"main_question": "What is the user's name?",
"secondary_question": "What is the user's age?",
}
},
},
)
assert '"main_question": "What is the user\'s name?"' in result
@@ -892,7 +892,7 @@ def test_interpolate_only_with_dict_inside_expected_output():
normal_string = "Hello {name}, welcome to {place}!"
result = interpolate_only(
input_string=normal_string, inputs={"name": "John", "place": "CrewAI"}
input_string=normal_string, inputs={"name": "John", "place": "CrewAI"},
)
assert result == "Hello John, welcome to CrewAI!"
@@ -904,7 +904,7 @@ def test_interpolate_only_with_dict_inside_expected_output():
assert result == no_placeholders
def test_task_output_str_with_pydantic():
def test_task_output_str_with_pydantic() -> None:
from crewai.tasks.output_format import OutputFormat
class ScoreOutput(BaseModel):
@@ -921,7 +921,7 @@ def test_task_output_str_with_pydantic():
assert str(task_output) == str(score_output)
def test_task_output_str_with_json_dict():
def test_task_output_str_with_json_dict() -> None:
from crewai.tasks.output_format import OutputFormat
json_dict = {"score": 4}
@@ -935,7 +935,7 @@ def test_task_output_str_with_json_dict():
assert str(task_output) == str(json_dict)
def test_task_output_str_with_raw():
def test_task_output_str_with_raw() -> None:
from crewai.tasks.output_format import OutputFormat
raw_output = "Raw task output"
@@ -949,7 +949,7 @@ def test_task_output_str_with_raw():
assert str(task_output) == raw_output
def test_task_output_str_with_pydantic_and_json_dict():
def test_task_output_str_with_pydantic_and_json_dict() -> None:
from crewai.tasks.output_format import OutputFormat
class ScoreOutput(BaseModel):
@@ -969,7 +969,7 @@ def test_task_output_str_with_pydantic_and_json_dict():
assert str(task_output) == str(score_output)
def test_task_output_str_with_none():
def test_task_output_str_with_none() -> None:
from crewai.tasks.output_format import OutputFormat
task_output = TaskOutput(
@@ -981,7 +981,7 @@ def test_task_output_str_with_none():
assert str(task_output) == ""
def test_key():
def test_key() -> None:
original_description = "Give me a list of 5 interesting ideas about {topic} to explore for an article, what makes them unique and interesting."
original_expected_output = "Bullet point list of 5 interesting ideas about {topic}."
task = Task(
@@ -989,7 +989,7 @@ def test_key():
expected_output=original_expected_output,
)
hash = hashlib.md5(
f"{original_description}|{original_expected_output}".encode()
f"{original_description}|{original_expected_output}".encode(),
).hexdigest()
assert task.key == hash, "The key should be the hash of the description."
@@ -1000,7 +1000,7 @@ def test_key():
), "The key should be the hash of the non-interpolated description."
def test_output_file_validation():
def test_output_file_validation() -> None:
"""Test output file path validation."""
# Valid paths
assert (
@@ -1068,7 +1068,7 @@ def test_output_file_validation():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_task_execution_times():
def test_task_execution_times() -> None:
researcher = Agent(
role="Researcher",
goal="Make the best research and analysis on content about AI and AI agents",
@@ -1093,8 +1093,8 @@ def test_task_execution_times():
assert task.execution_duration == (task.end_time - task.start_time).total_seconds()
def test_interpolate_with_list_of_strings():
task = Task(
def test_interpolate_with_list_of_strings() -> None:
Task(
description="Test list interpolation",
expected_output="List: {items}",
)
@@ -1111,8 +1111,8 @@ def test_interpolate_with_list_of_strings():
assert result == "Available items: []"
def test_interpolate_with_list_of_dicts():
task = Task(
def test_interpolate_with_list_of_dicts() -> None:
Task(
description="Test list of dicts interpolation",
expected_output="People: {people}",
)
@@ -1121,7 +1121,7 @@ def test_interpolate_with_list_of_dicts():
"people": [
{"name": "Alice", "age": 30, "skills": ["Python", "AI"]},
{"name": "Bob", "age": 25, "skills": ["Java", "Cloud"]},
]
],
}
result = interpolate_only("{people}", input_data)
@@ -1136,8 +1136,8 @@ def test_interpolate_with_list_of_dicts():
assert parsed_result[1]["skills"] == ["Java", "Cloud"]
def test_interpolate_with_nested_structures():
task = Task(
def test_interpolate_with_nested_structures() -> None:
Task(
description="Test nested structures",
expected_output="Company: {company}",
)
@@ -1153,7 +1153,7 @@ def test_interpolate_with_nested_structures():
},
{"name": "Sales", "employees": 20, "regions": {"north": 5, "south": 3}},
],
}
},
}
result = interpolate_only("{company}", input_data)
parsed = eval(result)
@@ -1164,8 +1164,8 @@ def test_interpolate_with_nested_structures():
assert parsed["departments"][1]["regions"]["north"] == 5
def test_interpolate_with_special_characters():
task = Task(
def test_interpolate_with_special_characters() -> None:
Task(
description="Test special characters in dicts",
expected_output="Data: {special_data}",
)
@@ -1176,7 +1176,7 @@ def test_interpolate_with_special_characters():
"unicode": "文字化けテスト",
"symbols": "!@#$%^&*()",
"empty": "",
}
},
}
result = interpolate_only("{special_data}", input_data)
parsed = eval(result)
@@ -1187,8 +1187,8 @@ def test_interpolate_with_special_characters():
assert parsed["empty"] == ""
def test_interpolate_mixed_types():
task = Task(
def test_interpolate_mixed_types() -> None:
Task(
description="Test mixed type interpolation",
expected_output="Mixed: {data}",
)
@@ -1203,7 +1203,7 @@ def test_interpolate_mixed_types():
"validated": True,
"tags": ["demo", "test", "temp"],
},
}
},
}
result = interpolate_only("{data}", input_data)
parsed = eval(result)
@@ -1213,8 +1213,8 @@ def test_interpolate_mixed_types():
assert parsed["metadata"]["tags"] == ["demo", "test", "temp"]
def test_interpolate_complex_combination():
task = Task(
def test_interpolate_complex_combination() -> None:
Task(
description="Test complex combination",
expected_output="Report: {report}",
)
@@ -1231,7 +1231,7 @@ def test_interpolate_complex_combination():
"metrics": {"sales": 18000, "expenses": 8500, "profit": 9500},
"top_products": ["Product C", "Product D"],
},
]
],
}
result = interpolate_only("{report}", input_data)
parsed = eval(result)
@@ -1242,8 +1242,8 @@ def test_interpolate_complex_combination():
assert "Product D" in parsed[1]["top_products"]
def test_interpolate_invalid_type_validation():
task = Task(
def test_interpolate_invalid_type_validation() -> None:
Task(
description="Test invalid type validation",
expected_output="Should never reach here",
)
@@ -1260,24 +1260,24 @@ def test_interpolate_invalid_type_validation():
"name": "John",
"age": 30,
"tags": {"a", "b", "c"}, # Set is invalid
}
},
}
with pytest.raises(ValueError) as excinfo:
interpolate_only("{data}", {"data": invalid_nested})
assert "Unsupported type set" in str(excinfo.value)
def test_interpolate_custom_object_validation():
task = Task(
def test_interpolate_custom_object_validation() -> None:
Task(
description="Test custom object rejection",
expected_output="Should never reach here",
)
class CustomObject:
def __init__(self, value):
def __init__(self, value) -> None:
self.value = value
def __str__(self):
def __str__(self) -> str:
return str(self.value)
# Test with custom object at top level
@@ -1298,13 +1298,13 @@ def test_interpolate_custom_object_validation():
# Test with deeply nested custom object
with pytest.raises(ValueError) as excinfo:
interpolate_only(
"{data}", {"data": {"level1": {"level2": [{"level3": CustomObject(5)}]}}}
"{data}", {"data": {"level1": {"level2": [{"level3": CustomObject(5)}]}}},
)
assert "Unsupported type CustomObject" in str(excinfo.value)
def test_interpolate_valid_complex_types():
task = Task(
def test_interpolate_valid_complex_types() -> None:
Task(
description="Test valid complex types",
expected_output="Validation should pass",
)
@@ -1327,8 +1327,8 @@ def test_interpolate_valid_complex_types():
assert parsed["stats"]["nested"]["deeper"]["b"] == 2.5
def test_interpolate_edge_cases():
task = Task(
def test_interpolate_edge_cases() -> None:
Task(
description="Test edge cases",
expected_output="Edge case handling",
)
@@ -1346,8 +1346,8 @@ def test_interpolate_edge_cases():
assert interpolate_only("{flag}", {"flag": False}) == "False"
def test_interpolate_valid_types():
task = Task(
def test_interpolate_valid_types() -> None:
Task(
description="Test valid types including null and boolean",
expected_output="Should pass validation",
)
@@ -1371,13 +1371,13 @@ def test_interpolate_valid_types():
assert parsed["nested"]["empty"] is None
def test_task_with_no_max_execution_time():
def test_task_with_no_max_execution_time() -> None:
researcher = Agent(
role="Researcher",
goal="Make the best research and analysis on content about AI and AI agents",
backstory="You're an expert researcher, specialized in technology, software engineering, AI and startups. You work as a freelancer and is now working on doing research and analysis for a new customer.",
allow_delegation=False,
max_execution_time=None
max_execution_time=None,
)
task = Task(
@@ -1393,13 +1393,13 @@ def test_task_with_no_max_execution_time():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_task_with_max_execution_time():
def test_task_with_max_execution_time() -> None:
from crewai.tools import tool
"""Test that execution raises TimeoutError when max_execution_time is exceeded."""
@tool("what amazing tool", result_as_answer=True)
def my_tool() -> str:
"My tool"
"""My tool."""
time.sleep(1)
return "okay"
@@ -1412,7 +1412,7 @@ def test_task_with_max_execution_time():
),
allow_delegation=False,
tools=[my_tool],
max_execution_time=4
max_execution_time=4,
)
task = Task(
@@ -1426,13 +1426,13 @@ def test_task_with_max_execution_time():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_task_with_max_execution_time_exceeded():
def test_task_with_max_execution_time_exceeded() -> None:
from crewai.tools import tool
"""Test that execution raises TimeoutError when max_execution_time is exceeded."""
@tool("what amazing tool", result_as_answer=True)
def my_tool() -> str:
"My tool"
"""My tool."""
time.sleep(10)
return "okay"
@@ -1445,7 +1445,7 @@ def test_task_with_max_execution_time_exceeded():
),
allow_delegation=False,
tools=[my_tool],
max_execution_time=1
max_execution_time=1,
)
task = Task(
@@ -1455,4 +1455,4 @@ def test_task_with_max_execution_time_exceeded():
)
with pytest.raises(TimeoutError):
task.execute_sync(agent=researcher)
task.execute_sync(agent=researcher)
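Most hunks in this file are two mechanical fixes: a trailing comma after the last argument of a wrapped call, and a -> None return annotation on test functions that return nothing. A minimal standalone sketch of both (illustrative only, not part of the diff; COM812 and ANN201 are my assumption about the ruff rules that produced them):

def build_greeting(name: str, place: str) -> str:
    return f"Hello {name}, welcome to {place}!"

# Before the fix: no return annotation, no trailing comma on the wrapped call.
def test_build_greeting_before():
    assert build_greeting(
        name="John", place="CrewAI"
    ) == "Hello John, welcome to CrewAI!"

# After the fix: -> None marks the test as returning nothing, and the
# trailing comma keeps future argument additions to single-line diffs.
def test_build_greeting_after() -> None:
    assert build_greeting(
        name="John", place="CrewAI",
    ) == "Hello John, welcome to CrewAI!"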

View File

@@ -8,7 +8,7 @@ from crewai.telemetry import Telemetry
@pytest.mark.parametrize(
"env_var,value,expected_ready",
("env_var", "value", "expected_ready"),
[
("OTEL_SDK_DISABLED", "true", False),
("OTEL_SDK_DISABLED", "TRUE", False),
@@ -18,7 +18,7 @@ from crewai.telemetry import Telemetry
("CREWAI_DISABLE_TELEMETRY", "false", True),
],
)
def test_telemetry_environment_variables(env_var, value, expected_ready):
def test_telemetry_environment_variables(env_var, value, expected_ready) -> None:
"""Test telemetry state with different environment variable configurations."""
with patch.dict(os.environ, {env_var: value}):
with patch("crewai.telemetry.telemetry.TracerProvider"):
@@ -26,7 +26,7 @@ def test_telemetry_environment_variables(env_var, value, expected_ready):
assert telemetry.ready is expected_ready
def test_telemetry_enabled_by_default():
def test_telemetry_enabled_by_default() -> None:
"""Test that telemetry is enabled by default."""
with patch.dict(os.environ, {}, clear=True):
with patch("crewai.telemetry.telemetry.TracerProvider"):
@@ -43,7 +43,7 @@ from opentelemetry import trace
side_effect=Exception("Test exception"),
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_telemetry_fails_due_connect_timeout(export_mock, logger_mock):
def test_telemetry_fails_due_connect_timeout(export_mock, logger_mock) -> None:
error = Exception("Test exception")
export_mock.side_effect = error
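The first hunk in this file rewrites the parametrize argument names from one comma-separated string to a tuple of strings. pytest accepts both spellings; the tuple form is simply what the pytest-style lint rule prefers (PT006 in ruff is my guess at the exact rule). A standalone sketch:

import pytest

# Tuple argnames instead of "env_var,value": each name is an explicit string,
# so there is nothing for pytest to split on commas.
@pytest.mark.parametrize(
    ("env_var", "value"),
    [
        ("OTEL_SDK_DISABLED", "true"),
        ("CREWAI_DISABLE_TELEMETRY", "false"),
    ],
)
def test_argnames_as_tuple(env_var: str, value: str) -> None:
    assert isinstance(env_var, str)
    assert isinstance(value, str)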

View File

@@ -6,7 +6,7 @@ import pytest
from crewai.telemetry import Telemetry
@pytest.mark.parametrize("env_var,value,expected_ready", [
@pytest.mark.parametrize(("env_var", "value", "expected_ready"), [
("OTEL_SDK_DISABLED", "true", False),
("OTEL_SDK_DISABLED", "TRUE", False),
("CREWAI_DISABLE_TELEMETRY", "true", False),
@@ -14,7 +14,7 @@ from crewai.telemetry import Telemetry
("OTEL_SDK_DISABLED", "false", True),
("CREWAI_DISABLE_TELEMETRY", "false", True),
])
def test_telemetry_environment_variables(env_var, value, expected_ready):
def test_telemetry_environment_variables(env_var, value, expected_ready) -> None:
"""Test telemetry state with different environment variable configurations."""
with patch.dict(os.environ, {env_var: value}):
with patch("crewai.telemetry.telemetry.TracerProvider"):
@@ -22,7 +22,7 @@ def test_telemetry_environment_variables(env_var, value, expected_ready):
assert telemetry.ready is expected_ready
def test_telemetry_enabled_by_default():
def test_telemetry_enabled_by_default() -> None:
"""Test that telemetry is enabled by default."""
with patch.dict(os.environ, {}, clear=True):
with patch("crewai.telemetry.telemetry.TracerProvider"):

View File

@@ -6,12 +6,13 @@ from crewai.flow.persistence import persist
class PoemState(FlowState):
"""Test state model with default values that should be overridden."""
sentence_count: int = 1000 # Default that should be overridden
has_set_count: bool = False # Track whether we've set the count
poem_type: str = ""
def test_default_value_override():
def test_default_value_override() -> None:
"""Test that persisted state values override class defaults."""
@persist()
@@ -19,7 +20,7 @@ def test_default_value_override():
initial_state = PoemState
@start()
def set_sentence_count(self):
def set_sentence_count(self) -> None:
if self.state.has_set_count and self.state.sentence_count == 2:
self.state.sentence_count = 3
@@ -59,7 +60,7 @@ def test_default_value_override():
assert flow4.state.sentence_count == 1000 # Should load 1000, not 2
def test_multi_step_default_override():
def test_multi_step_default_override() -> None:
"""Test default value override with multiple start methods."""
@persist()
@@ -67,15 +68,13 @@ def test_multi_step_default_override():
initial_state = PoemState
@start()
def set_sentence_count(self):
print("Setting sentence count")
def set_sentence_count(self) -> None:
if not self.state.has_set_count:
self.state.sentence_count = 3
self.state.has_set_count = True
@listen(set_sentence_count)
def set_poem_type(self):
print("Setting poem type")
def set_poem_type(self) -> None:
if self.state.sentence_count == 3:
self.state.poem_type = "haiku"
elif self.state.sentence_count == 5:
@@ -84,8 +83,8 @@ def test_multi_step_default_override():
self.state.poem_type = "free_verse"
@listen(set_poem_type)
def finished(self):
print("finished")
def finished(self) -> None:
pass
# First run - should set both sentence count and poem type
flow1 = MultiStepPoemFlow()
@@ -98,7 +97,7 @@ def test_multi_step_default_override():
flow2 = MultiStepPoemFlow()
flow2.kickoff(inputs={
"id": original_uuid,
"sentence_count": 5
"sentence_count": 5,
})
assert flow2.state.sentence_count == 5
assert flow2.state.poem_type == "limerick"
@@ -106,7 +105,7 @@ def test_multi_step_default_override():
# Third run - new flow without persisted state should use defaults
flow3 = MultiStepPoemFlow()
flow3.kickoff(inputs={
"id": original_uuid
"id": original_uuid,
})
assert flow3.state.sentence_count == 5
assert flow3.state.poem_type == "limerick"
assert flow3.state.poem_type == "limerick"

View File

@@ -1,10 +1,6 @@
"""Test flow state persistence functionality."""
import os
from typing import Dict
import pytest
from pydantic import BaseModel
from crewai.flow.flow import Flow, FlowState, listen, start
from crewai.flow.persistence import persist
@@ -18,17 +14,17 @@ class TestState(FlowState):
message: str = ""
def test_persist_decorator_saves_state(tmp_path, caplog):
def test_persist_decorator_saves_state(tmp_path, caplog) -> None:
"""Test that @persist decorator saves state in SQLite."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class TestFlow(Flow[Dict[str, str]]):
initial_state = dict() # Use dict instance as initial state
class TestFlow(Flow[dict[str, str]]):
initial_state = {} # Use dict instance as initial state
@start()
@persist(persistence)
def init_step(self):
def init_step(self) -> None:
self.state["message"] = "Hello, World!"
self.state["id"] = "test-uuid" # Ensure we have an ID for persistence
@@ -42,7 +38,7 @@ def test_persist_decorator_saves_state(tmp_path, caplog):
assert saved_state["message"] == "Hello, World!"
def test_structured_state_persistence(tmp_path):
def test_structured_state_persistence(tmp_path) -> None:
"""Test persistence with Pydantic model state."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
@@ -52,7 +48,7 @@ def test_structured_state_persistence(tmp_path):
@start()
@persist(persistence)
def count_up(self):
def count_up(self) -> None:
self.state.counter += 1
self.state.message = f"Count is {self.state.counter}"
@@ -67,7 +63,7 @@ def test_structured_state_persistence(tmp_path):
assert saved_state["message"] == "Count is 1"
def test_flow_state_restoration(tmp_path):
def test_flow_state_restoration(tmp_path) -> None:
"""Test restoring flow state from persistence with various restoration methods."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
@@ -76,7 +72,7 @@ def test_flow_state_restoration(tmp_path):
class RestorableFlow(Flow[TestState]):
@start()
@persist(persistence)
def set_message(self):
def set_message(self) -> None:
if self.state.message == "":
self.state.message = "Original message"
if self.state.counter == 0:
@@ -106,7 +102,7 @@ def test_flow_state_restoration(tmp_path):
assert flow3.state.message == "Updated message" # Overridden
def test_multiple_method_persistence(tmp_path):
def test_multiple_method_persistence(tmp_path) -> None:
"""Test state persistence across multiple method executions."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
@@ -114,7 +110,7 @@ def test_multiple_method_persistence(tmp_path):
class MultiStepFlow(Flow[TestState]):
@start()
@persist(persistence)
def step_1(self):
def step_1(self) -> None:
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
@@ -124,7 +120,7 @@ def test_multiple_method_persistence(tmp_path):
@listen(step_1)
@persist(persistence)
def step_2(self):
def step_2(self) -> None:
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"
@@ -144,7 +140,7 @@ def test_multiple_method_persistence(tmp_path):
class NoPersistenceMultiStepFlow(Flow[TestState]):
@start()
@persist(persistence)
def step_1(self):
def step_1(self) -> None:
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
@@ -153,7 +149,7 @@ def test_multiple_method_persistence(tmp_path):
self.state.message = "Step 1"
@listen(step_1)
def step_2(self):
def step_2(self) -> None:
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"
@@ -170,7 +166,7 @@ def test_multiple_method_persistence(tmp_path):
assert final_state.message == "Step 99999"
def test_persist_decorator_verbose_logging(tmp_path, caplog):
def test_persist_decorator_verbose_logging(tmp_path, caplog) -> None:
"""Test that @persist decorator's verbose parameter controls logging."""
# Set logging level to ensure we capture all logs
caplog.set_level("INFO")
@@ -179,12 +175,12 @@ def test_persist_decorator_verbose_logging(tmp_path, caplog):
persistence = SQLiteFlowPersistence(db_path)
# Test with verbose=False (default)
class QuietFlow(Flow[Dict[str, str]]):
initial_state = dict()
class QuietFlow(Flow[dict[str, str]]):
initial_state = {}
@start()
@persist(persistence) # Default verbose=False
def init_step(self):
def init_step(self) -> None:
self.state["message"] = "Hello, World!"
self.state["id"] = "test-uuid-1"
@@ -196,12 +192,12 @@ def test_persist_decorator_verbose_logging(tmp_path, caplog):
caplog.clear()
# Test with verbose=True
class VerboseFlow(Flow[Dict[str, str]]):
initial_state = dict()
class VerboseFlow(Flow[dict[str, str]]):
initial_state = {}
@start()
@persist(persistence, verbose=True)
def init_step(self):
def init_step(self) -> None:
self.state["message"] = "Hello, World!"
self.state["id"] = "test-uuid-2"

View File

@@ -1,4 +1,3 @@
import asyncio
from typing import cast
from unittest.mock import Mock
@@ -35,10 +34,9 @@ class WebSearchTool(BaseTool):
# This is a mock implementation
if "tokyo" in query.lower():
return "Tokyo's population in 2023 was approximately 21 million people in the city proper, and 37 million in the greater metropolitan area."
elif "climate change" in query.lower() and "coral" in query.lower():
if "climate change" in query.lower() and "coral" in query.lower():
return "Climate change severely impacts coral reefs through: 1) Ocean warming causing coral bleaching, 2) Ocean acidification reducing calcification, 3) Sea level rise affecting light availability, 4) Increased storm frequency damaging reef structures. Sources: NOAA Coral Reef Conservation Program, Global Coral Reef Alliance."
else:
return f"Found information about {query}: This is a simulated search result for demonstration purposes."
return f"Found information about {query}: This is a simulated search result for demonstration purposes."
# Define Mock Calculator Tool
@@ -54,7 +52,7 @@ class CalculatorTool(BaseTool):
result = eval(expression, {"__builtins__": {}})
return f"The result of {expression} is {result}"
except Exception as e:
return f"Error calculating {expression}: {str(e)}"
return f"Error calculating {expression}: {e!s}"
# Define a custom response format using Pydantic
@@ -68,7 +66,7 @@ class ResearchResult(BaseModel):
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.parametrize("verbose", [True, False])
def test_lite_agent_created_with_correct_parameters(monkeypatch, verbose):
def test_lite_agent_created_with_correct_parameters(monkeypatch, verbose) -> None:
"""Test that LiteAgent is created with the correct parameters when Agent.kickoff() is called."""
# Create a test agent with specific parameters
llm = LLM(model="gpt-4o-mini")
@@ -93,7 +91,7 @@ def test_lite_agent_created_with_correct_parameters(monkeypatch, verbose):
# Define a mock LiteAgent class that captures its arguments
class MockLiteAgent(original_lite_agent):
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
nonlocal created_lite_agent
created_lite_agent = kwargs
super().__init__(**kwargs)
@@ -129,7 +127,7 @@ def test_lite_agent_created_with_correct_parameters(monkeypatch, verbose):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_lite_agent_with_tools():
def test_lite_agent_with_tools() -> None:
"""Test that Agent can use tools."""
# Create a LiteAgent with tools
llm = LLM(model="gpt-4o-mini")
@@ -143,7 +141,7 @@ def test_lite_agent_with_tools():
)
result = agent.kickoff(
"What is the population of Tokyo and how many people would that be per square kilometer if Tokyo's area is 2,194 square kilometers?"
"What is the population of Tokyo and how many people would that be per square kilometer if Tokyo's area is 2,194 square kilometers?",
)
assert (
@@ -156,7 +154,7 @@ def test_lite_agent_with_tools():
received_events = []
@crewai_event_bus.on(ToolUsageStartedEvent)
def event_handler(source, event):
def event_handler(source, event) -> None:
received_events.append(event)
agent.kickoff("What are the effects of climate change on coral reefs?")
@@ -196,13 +194,10 @@ def test_lite_agent_structured_output():
response_format=SimpleOutput,
)
print(f"\n=== Agent Result Type: {type(result)}")
print(f"=== Agent Result: {result}")
print(f"=== Pydantic: {result.pydantic}")
assert result.pydantic is not None, "Should return a Pydantic model"
output = cast(SimpleOutput, result.pydantic)
output = cast("SimpleOutput", result.pydantic)
assert isinstance(output.summary, str), "Summary should be a string"
assert len(output.summary) > 0, "Summary should not be empty"
@@ -217,7 +212,7 @@ def test_lite_agent_structured_output():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_lite_agent_returns_usage_metrics():
def test_lite_agent_returns_usage_metrics() -> None:
"""Test that LiteAgent returns usage metrics."""
llm = LLM(model="gpt-4o-mini")
agent = Agent(
@@ -230,7 +225,7 @@ def test_lite_agent_returns_usage_metrics():
)
result = agent.kickoff(
"What is the population of Tokyo? Return your strucutred output in JSON format with the following fields: summary, confidence"
"What is the population of Tokyo? Return your strucutred output in JSON format with the following fields: summary, confidence",
)
assert result.usage_metrics is not None
@@ -239,7 +234,7 @@ def test_lite_agent_returns_usage_metrics():
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.asyncio
async def test_lite_agent_returns_usage_metrics_async():
async def test_lite_agent_returns_usage_metrics_async() -> None:
"""Test that LiteAgent returns usage metrics when run asynchronously."""
llm = LLM(model="gpt-4o-mini")
agent = Agent(
@@ -252,7 +247,7 @@ async def test_lite_agent_returns_usage_metrics_async():
)
result = await agent.kickoff_async(
"What is the population of Tokyo? Return your strucutred output in JSON format with the following fields: summary, confidence"
"What is the population of Tokyo? Return your strucutred output in JSON format with the following fields: summary, confidence",
)
assert isinstance(result, LiteAgentOutput)
assert "21 million" in result.raw or "37 million" in result.raw
@@ -263,7 +258,7 @@ async def test_lite_agent_returns_usage_metrics_async():
class TestFlow(Flow):
"""A test flow that creates and runs an agent."""
def __init__(self, llm, tools):
def __init__(self, llm, tools) -> None:
self.llm = llm
self.tools = tools
super().__init__()
@@ -280,14 +275,14 @@ class TestFlow(Flow):
return agent.kickoff("Test query")
def verify_agent_parent_flow(result, agent, flow):
def verify_agent_parent_flow(result, agent, flow) -> None:
"""Verify that both the result and agent have the correct parent flow."""
assert result.parent_flow is flow
assert agent is not None
assert agent.parent_flow is flow
def test_sets_parent_flow_when_inside_flow():
def test_sets_parent_flow_when_inside_flow() -> None:
captured_agent = None
mock_llm = Mock(spec=LLM)
@@ -309,9 +304,9 @@ def test_sets_parent_flow_when_inside_flow():
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
def capture_agent(source, event):
def capture_agent(source, event) -> None:
nonlocal captured_agent
captured_agent = source
result = flow.kickoff()
flow.kickoff()
assert captured_agent.parent_flow is flow
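Two hunks in this file go beyond formatting: the elif after a return becomes a plain if (no behaviour change, because the earlier branch has already returned), and {str(e)} in an f-string becomes the equivalent conversion flag {e!s}. A self-contained sketch of both patterns, assuming the usual ruff rules (RET505, RUF010) are what fired:

def describe(query: str) -> str:
    # The early return makes the next branch reachable only when the first
    # condition is false, so elif adds nothing over if.
    if "tokyo" in query.lower():
        return "population facts"
    if "coral" in query.lower():
        return "climate facts"
    return f"generic result for {query}"


def safe_eval(expression: str) -> str:
    try:
        result = eval(expression, {"__builtins__": {}})
        return f"The result of {expression} is {result}"
    except Exception as e:
        # e!s is the f-string spelling of str(e).
        return f"Error calculating {expression}: {e!s}"

assert describe("Weather in Tokyo") == "population facts"
assert safe_eval("2 + 2") == "The result of 2 + 2 is 4"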

View File

@@ -6,9 +6,8 @@ from crewai import LLM, Agent, Crew, Task
@pytest.mark.skip(reason="Only run manually with valid API keys")
def test_multimodal_agent_with_image_url():
"""
Test that a multimodal agent can process images without validation errors.
def test_multimodal_agent_with_image_url() -> None:
"""Test that a multimodal agent can process images without validation errors.
This test reproduces the scenario from issue #2475.
"""
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
@@ -18,7 +17,7 @@ def test_multimodal_agent_with_image_url():
llm = LLM(
model="openai/gpt-4o", # model with vision capabilities
api_key=OPENAI_API_KEY,
temperature=0.7
temperature=0.7,
)
expert_analyst = Agent(
@@ -28,7 +27,7 @@ def test_multimodal_agent_with_image_url():
llm=llm,
verbose=True,
allow_delegation=False,
multimodal=True
multimodal=True,
)
inspection_task = Task(
@@ -40,7 +39,7 @@ def test_multimodal_agent_with_image_url():
Provide a detailed report highlighting any issues found.
""",
expected_output="A detailed report highlighting any issues found",
agent=expert_analyst
agent=expert_analyst,
)
crew = Crew(agents=[expert_analyst], tasks=[inspection_task])
Crew(agents=[expert_analyst], tasks=[inspection_task])

View File

@@ -1,4 +1,4 @@
from unittest.mock import ANY, Mock, patch
from unittest.mock import Mock, patch
import pytest
@@ -13,7 +13,7 @@ from crewai.utilities.events import (
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
def test_task_without_guardrail():
def test_task_without_guardrail() -> None:
"""Test that tasks work normally without guardrails (backward compatibility)."""
agent = Mock()
agent.role = "test_agent"
@@ -27,7 +27,7 @@ def test_task_without_guardrail():
assert result.raw == "test result"
def test_task_with_successful_guardrail_func():
def test_task_with_successful_guardrail_func() -> None:
"""Test that successful guardrail validation passes transformed result."""
def guardrail(result: TaskOutput):
@@ -45,7 +45,7 @@ def test_task_with_successful_guardrail_func():
assert result.raw == "TEST RESULT"
def test_task_with_failing_guardrail():
def test_task_with_failing_guardrail() -> None:
"""Test that failing guardrail triggers retry with error context."""
def guardrail(result: TaskOutput):
@@ -72,7 +72,7 @@ def test_task_with_failing_guardrail():
assert task.retry_count == 1
def test_task_with_guardrail_retries():
def test_task_with_guardrail_retries() -> None:
"""Test that guardrail respects max_retries configuration."""
def guardrail(result: TaskOutput):
@@ -98,7 +98,7 @@ def test_task_with_guardrail_retries():
assert "Invalid format" in str(exc_info.value)
def test_guardrail_error_in_context():
def test_guardrail_error_in_context() -> None:
"""Test that guardrail error is passed in context for retry."""
def guardrail(result: TaskOutput):
@@ -118,7 +118,7 @@ def test_guardrail_error_in_context():
# Mock execute_task to succeed on second attempt
first_call = True
def execute_task(task, context, tools):
def execute_task(task, context, tools) -> str:
nonlocal first_call
if first_call:
first_call = False
@@ -152,9 +152,9 @@ def task_output():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_task_guardrail_process_output(task_output):
def test_task_guardrail_process_output(task_output) -> None:
guardrail = LLMGuardrail(
description="Ensure the result has less than 10 words", llm=LLM(model="gpt-4o")
description="Ensure the result has less than 10 words", llm=LLM(model="gpt-4o"),
)
result = guardrail(task_output)
@@ -163,7 +163,7 @@ def test_task_guardrail_process_output(task_output):
assert "exceeding the guardrail limit of fewer than" in result[1].lower()
guardrail = LLMGuardrail(
description="Ensure the result has less than 500 words", llm=LLM(model="gpt-4o")
description="Ensure the result has less than 500 words", llm=LLM(model="gpt-4o"),
)
result = guardrail(task_output)
@@ -172,27 +172,27 @@ def test_task_guardrail_process_output(task_output):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_guardrail_emits_events(sample_agent):
def test_guardrail_emits_events(sample_agent) -> None:
started_guardrail = []
completed_guardrail = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def handle_guardrail_started(source, event):
def handle_guardrail_started(source, event) -> None:
started_guardrail.append(
{"guardrail": event.guardrail, "retry_count": event.retry_count}
{"guardrail": event.guardrail, "retry_count": event.retry_count},
)
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def handle_guardrail_completed(source, event):
def handle_guardrail_completed(source, event) -> None:
completed_guardrail.append(
{
"success": event.success,
"result": event.result,
"error": event.error,
"retry_count": event.retry_count,
}
},
)
task = Task(
@@ -248,7 +248,7 @@ def test_guardrail_emits_events(sample_agent):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_guardrail_when_an_error_occurs(sample_agent, task_output):
def test_guardrail_when_an_error_occurs(sample_agent, task_output) -> None:
with (
patch(
"crewai.Agent.kickoff",

View File

@@ -24,7 +24,7 @@ def vcr_config(request) -> dict:
@pytest.mark.vcr(filter_headers=["authorization"])
def test_delegate_work():
def test_delegate_work() -> None:
result = delegate_tool.run(
coworker="researcher",
task="share your take on AI Agents",
@@ -38,7 +38,7 @@ def test_delegate_work():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_delegate_work_with_wrong_co_worker_variable():
def test_delegate_work_with_wrong_co_worker_variable() -> None:
result = delegate_tool.run(
coworker="researcher",
task="share your take on AI Agents",
@@ -52,7 +52,7 @@ def test_delegate_work_with_wrong_co_worker_variable():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_ask_question():
def test_ask_question() -> None:
result = ask_tool.run(
coworker="researcher",
question="do you hate AI Agents?",
@@ -66,7 +66,7 @@ def test_ask_question():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_ask_question_with_wrong_co_worker_variable():
def test_ask_question_with_wrong_co_worker_variable() -> None:
result = ask_tool.run(
coworker="researcher",
question="do you hate AI Agents?",
@@ -80,7 +80,7 @@ def test_ask_question_with_wrong_co_worker_variable():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_delegate_work_withwith_coworker_as_array():
def test_delegate_work_withwith_coworker_as_array() -> None:
result = delegate_tool.run(
coworker="[researcher]",
task="share your take on AI Agents",
@@ -94,7 +94,7 @@ def test_delegate_work_withwith_coworker_as_array():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_ask_question_with_coworker_as_array():
def test_ask_question_with_coworker_as_array() -> None:
result = ask_tool.run(
coworker="[researcher]",
question="do you hate AI Agents?",
@@ -107,7 +107,7 @@ def test_ask_question_with_coworker_as_array():
)
def test_delegate_work_to_wrong_agent():
def test_delegate_work_to_wrong_agent() -> None:
result = ask_tool.run(
coworker="writer",
question="share your take on AI Agents",
@@ -120,7 +120,7 @@ def test_delegate_work_to_wrong_agent():
)
def test_ask_question_to_wrong_agent():
def test_ask_question_to_wrong_agent() -> None:
result = ask_tool.run(
coworker="writer",
question="do you hate AI Agents?",

View File

@@ -1,13 +1,11 @@
import asyncio
import inspect
import unittest
from typing import Any, Callable, Dict, List
from collections.abc import Callable
from unittest.mock import patch
from crewai.tools import BaseTool, tool
def test_creating_a_tool_using_annotation():
def test_creating_a_tool_using_annotation() -> None:
@tool("Name of my tool")
def my_tool(question: str) -> str:
"""Clear description for what this tool is useful for, your agent will need this information to use it."""
@@ -20,7 +18,7 @@ def test_creating_a_tool_using_annotation():
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, your agent will need this information to use it."
)
assert my_tool.args_schema.model_json_schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
"question": {"title": "Question", "type": "string"},
}
assert (
my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
@@ -34,7 +32,7 @@ def test_creating_a_tool_using_annotation():
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, your agent will need this information to use it."
)
assert converted_tool.args_schema.model_json_schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
"question": {"title": "Question", "type": "string"},
}
assert (
converted_tool.func("What is the meaning of life?")
@@ -42,7 +40,7 @@ def test_creating_a_tool_using_annotation():
)
def test_creating_a_tool_using_baseclass():
def test_creating_a_tool_using_baseclass() -> None:
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, your agent will need this information to use it."
@@ -59,7 +57,7 @@ def test_creating_a_tool_using_baseclass():
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, your agent will need this information to use it."
)
assert my_tool.args_schema.model_json_schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
"question": {"title": "Question", "type": "string"},
}
assert my_tool.run("What is the meaning of life?") == "What is the meaning of life?"
@@ -71,7 +69,7 @@ def test_creating_a_tool_using_baseclass():
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, your agent will need this information to use it."
)
assert converted_tool.args_schema.model_json_schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
"question": {"title": "Question", "type": "string"},
}
assert (
converted_tool._run("What is the meaning of life?")
@@ -79,7 +77,7 @@ def test_creating_a_tool_using_baseclass():
)
def test_setting_cache_function():
def test_setting_cache_function() -> None:
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, your agent will need this information to use it."
@@ -93,7 +91,7 @@ def test_setting_cache_function():
assert not my_tool.cache_function()
def test_default_cache_function_is_true():
def test_default_cache_function_is_true() -> None:
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, your agent will need this information to use it."
@@ -106,30 +104,31 @@ def test_default_cache_function_is_true():
assert my_tool.cache_function()
def test_result_as_answer_in_tool_decorator():
def test_result_as_answer_in_tool_decorator() -> None:
@tool("Tool with result as answer", result_as_answer=True)
def my_tool_with_result_as_answer(question: str) -> str:
"""This tool will return its result as the final answer."""
return question
assert my_tool_with_result_as_answer.result_as_answer is True
converted_tool = my_tool_with_result_as_answer.to_structured_tool()
assert converted_tool.result_as_answer is True
@tool("Tool with default result_as_answer")
def my_tool_with_default(question: str) -> str:
"""This tool uses the default result_as_answer value."""
return question
assert my_tool_with_default.result_as_answer is False
converted_tool = my_tool_with_default.to_structured_tool()
assert converted_tool.result_as_answer is False
class SyncTool(BaseTool):
"""Test implementation with a synchronous _run method"""
"""Test implementation with a synchronous _run method."""
name: str = "sync_tool"
description: str = "A synchronous tool for testing"
@@ -139,7 +138,8 @@ class SyncTool(BaseTool):
class AsyncTool(BaseTool):
"""Test implementation with an asynchronous _run method"""
"""Test implementation with an asynchronous _run method."""
name: str = "async_tool"
description: str = "An asynchronous tool for testing"
@@ -149,7 +149,7 @@ class AsyncTool(BaseTool):
return f"Processed {input_text} asynchronously"
def test_sync_run_returns_direct_result():
def test_sync_run_returns_direct_result() -> None:
"""Test that _run in a synchronous tool returns a direct result, not a coroutine."""
tool = SyncTool()
result = tool._run(input_text="hello")
@@ -161,7 +161,7 @@ def test_sync_run_returns_direct_result():
assert run_result == "Processed hello synchronously"
def test_async_run_returns_coroutine():
def test_async_run_returns_coroutine() -> None:
"""Test that _run in an asynchronous tool returns a coroutine object."""
tool = AsyncTool()
result = tool._run(input_text="hello")
@@ -170,11 +170,11 @@ def test_async_run_returns_coroutine():
result.close() # Clean up the coroutine
def test_run_calls_asyncio_run_for_async_tools():
def test_run_calls_asyncio_run_for_async_tools() -> None:
"""Test that asyncio.run is called when using async tools."""
async_tool = AsyncTool()
with patch('asyncio.run') as mock_run:
with patch("asyncio.run") as mock_run:
mock_run.return_value = "Processed test asynchronously"
async_result = async_tool.run(input_text="test")
@@ -182,11 +182,11 @@ def test_run_calls_asyncio_run_for_async_tools():
assert async_result == "Processed test asynchronously"
def test_run_does_not_call_asyncio_run_for_sync_tools():
def test_run_does_not_call_asyncio_run_for_sync_tools() -> None:
"""Test that asyncio.run is NOT called when using sync tools."""
sync_tool = SyncTool()
with patch('asyncio.run') as mock_run:
with patch("asyncio.run") as mock_run:
sync_result = sync_tool.run(input_text="test")
mock_run.assert_not_called()
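The import hunk at the top of this file replaces typing.Callable with collections.abc.Callable and drops unused typing names, while the patch calls move from single to double quotes. Both are pure style rewrites; a sketch under those assumptions (UP035 and the flake8-quotes checks are my guess at the rules involved):

from collections.abc import Callable
from unittest.mock import patch


def apply(func: Callable[[str], str], value: str) -> str:
    # collections.abc.Callable replaces the typing.Callable alias.
    return func(value)


with patch("asyncio.run") as mock_run:  # double quotes, matching the project style
    mock_run.return_value = "done"
    assert apply(str.upper, "ok") == "OK"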

View File

@@ -1,4 +1,3 @@
from typing import Optional
import pytest
from pydantic import BaseModel, Field
@@ -26,8 +25,8 @@ def schema_class():
class InternalCrewStructuredTool:
def test_initialization(self, basic_function, schema_class):
"""Test basic initialization of CrewStructuredTool"""
def test_initialization(self, basic_function, schema_class) -> None:
"""Test basic initialization of CrewStructuredTool."""
tool = CrewStructuredTool(
name="test_tool",
description="Test tool description",
@@ -40,10 +39,10 @@ class InternalCrewStructuredTool:
assert tool.func == basic_function
assert tool.args_schema == schema_class
def test_from_function(self, basic_function):
"""Test creating tool from function"""
def test_from_function(self, basic_function) -> None:
"""Test creating tool from function."""
tool = CrewStructuredTool.from_function(
func=basic_function, name="test_tool", description="Test description"
func=basic_function, name="test_tool", description="Test description",
)
assert tool.name == "test_tool"
@@ -51,8 +50,8 @@ class InternalCrewStructuredTool:
assert tool.func == basic_function
assert isinstance(tool.args_schema, type(BaseModel))
def test_validate_function_signature(self, basic_function, schema_class):
"""Test function signature validation"""
def test_validate_function_signature(self, basic_function, schema_class) -> None:
"""Test function signature validation."""
tool = CrewStructuredTool(
name="test_tool",
description="Test tool",
@@ -64,44 +63,44 @@ class InternalCrewStructuredTool:
tool._validate_function_signature()
@pytest.mark.asyncio
async def test_ainvoke(self, basic_function):
"""Test asynchronous invocation"""
async def test_ainvoke(self, basic_function) -> None:
"""Test asynchronous invocation."""
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
result = await tool.ainvoke(input={"param1": "test"})
assert result == "test 0"
def test_parse_args_dict(self, basic_function):
"""Test parsing dictionary arguments"""
def test_parse_args_dict(self, basic_function) -> None:
"""Test parsing dictionary arguments."""
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
parsed = tool._parse_args({"param1": "test", "param2": 42})
assert parsed["param1"] == "test"
assert parsed["param2"] == 42
def test_parse_args_string(self, basic_function):
"""Test parsing string arguments"""
def test_parse_args_string(self, basic_function) -> None:
"""Test parsing string arguments."""
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
parsed = tool._parse_args('{"param1": "test", "param2": 42}')
assert parsed["param1"] == "test"
assert parsed["param2"] == 42
def test_complex_types(self):
"""Test handling of complex parameter types"""
def test_complex_types(self) -> None:
"""Test handling of complex parameter types."""
def complex_func(nested: dict, items: list) -> str:
"""Process complex types."""
return f"Processed {len(items)} items with {len(nested)} nested keys"
tool = CrewStructuredTool.from_function(
func=complex_func, name="test_tool", description="Test complex types"
func=complex_func, name="test_tool", description="Test complex types",
)
result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]})
assert result == "Processed 3 items with 1 nested keys"
def test_schema_inheritance(self):
"""Test tool creation with inherited schema"""
def test_schema_inheritance(self) -> None:
"""Test tool creation with inherited schema."""
def extended_func(base_param: str, extra_param: int) -> str:
"""Test function with inherited schema."""
@@ -114,25 +113,25 @@ class InternalCrewStructuredTool:
extra_param: int
tool = CrewStructuredTool.from_function(
func=extended_func, name="test_tool", args_schema=ExtendedSchema
func=extended_func, name="test_tool", args_schema=ExtendedSchema,
)
result = tool.invoke({"base_param": "test", "extra_param": 42})
assert result == "test 42"
def test_default_values_in_schema(self):
"""Test handling of default values in schema"""
def test_default_values_in_schema(self) -> None:
"""Test handling of default values in schema."""
def default_func(
required_param: str,
optional_param: str = "default",
nullable_param: Optional[int] = None,
nullable_param: int | None = None,
) -> str:
"""Test function with default values."""
return f"{required_param} {optional_param} {nullable_param}"
tool = CrewStructuredTool.from_function(
func=default_func, name="test_tool", description="Test defaults"
func=default_func, name="test_tool", description="Test defaults",
)
# Test with minimal parameters
@@ -141,6 +140,6 @@ class InternalCrewStructuredTool:
# Test with all parameters
result = tool.invoke(
{"required_param": "test", "optional_param": "custom", "nullable_param": 42}
{"required_param": "test", "optional_param": "custom", "nullable_param": 42},
)
assert result == "test custom 42"
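The nullable_param change above is the PEP 604 union syntax: Optional[int] and int | None describe the same type, and the | form removes the need to import Optional. A tiny sketch (UP007/UP045 are my assumption about the pyupgrade rule; the | syntax evaluated at definition time needs Python 3.10+):

def default_func(
    required_param: str,
    optional_param: str = "default",
    nullable_param: int | None = None,
) -> str:
    # int | None is the PEP 604 spelling of Optional[int].
    return f"{required_param} {optional_param} {nullable_param}"

assert default_func("test") == "test default None"
assert default_func("test", "custom", 42) == "test custom 42"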

View File

@@ -20,10 +20,10 @@ from crewai.utilities.events.tool_usage_events import (
class RandomNumberToolInput(BaseModel):
min_value: int = Field(
..., description="The minimum value of the range (inclusive)"
..., description="The minimum value of the range (inclusive)",
)
max_value: int = Field(
..., description="The maximum value of the range (inclusive)"
..., description="The maximum value of the range (inclusive)",
)
@@ -52,19 +52,19 @@ example_task = Task(
)
def test_random_number_tool_range():
def test_random_number_tool_range() -> None:
tool = RandomNumberTool()
result = tool._run(1, 10)
assert 1 <= result <= 10
def test_random_number_tool_invalid_range():
def test_random_number_tool_invalid_range() -> None:
tool = RandomNumberTool()
with pytest.raises(ValueError):
tool._run(10, 1) # min_value > max_value
def test_random_number_tool_schema():
def test_random_number_tool_schema() -> None:
tool = RandomNumberTool()
# Get the schema using model_json_schema()
@@ -93,7 +93,7 @@ def test_random_number_tool_schema():
)
def test_tool_usage_render():
def test_tool_usage_render() -> None:
tool = RandomNumberTool()
tool_usage = ToolUsage(
@@ -128,7 +128,7 @@ def test_tool_usage_render():
)
def test_validate_tool_input_booleans_and_none():
def test_validate_tool_input_booleans_and_none() -> None:
# Create a ToolUsage instance with mocks
tool_usage = ToolUsage(
tools_handler=MagicMock(),
@@ -147,7 +147,7 @@ def test_validate_tool_input_booleans_and_none():
assert arguments == expected_arguments
def test_validate_tool_input_mixed_types():
def test_validate_tool_input_mixed_types() -> None:
# Create a ToolUsage instance with mocks
tool_usage = ToolUsage(
tools_handler=MagicMock(),
@@ -166,7 +166,7 @@ def test_validate_tool_input_mixed_types():
assert arguments == expected_arguments
def test_validate_tool_input_single_quotes():
def test_validate_tool_input_single_quotes() -> None:
# Create a ToolUsage instance with mocks
tool_usage = ToolUsage(
tools_handler=MagicMock(),
@@ -185,7 +185,7 @@ def test_validate_tool_input_single_quotes():
assert arguments == expected_arguments
def test_validate_tool_input_invalid_json_repairable():
def test_validate_tool_input_invalid_json_repairable() -> None:
# Create a ToolUsage instance with mocks
tool_usage = ToolUsage(
tools_handler=MagicMock(),
@@ -204,7 +204,7 @@ def test_validate_tool_input_invalid_json_repairable():
assert arguments == expected_arguments
def test_validate_tool_input_with_special_characters():
def test_validate_tool_input_with_special_characters() -> None:
# Create a ToolUsage instance with mocks
tool_usage = ToolUsage(
tools_handler=MagicMock(),
@@ -223,7 +223,7 @@ def test_validate_tool_input_with_special_characters():
assert arguments == expected_arguments
def test_validate_tool_input_none_input():
def test_validate_tool_input_none_input() -> None:
tool_usage = ToolUsage(
tools_handler=MagicMock(),
tools=[],
@@ -237,7 +237,7 @@ def test_validate_tool_input_none_input():
assert arguments == {}
def test_validate_tool_input_valid_json():
def test_validate_tool_input_valid_json() -> None:
tool_usage = ToolUsage(
tools_handler=MagicMock(),
tools=[],
@@ -254,7 +254,7 @@ def test_validate_tool_input_valid_json():
assert arguments == expected_arguments
def test_validate_tool_input_python_dict():
def test_validate_tool_input_python_dict() -> None:
tool_usage = ToolUsage(
tools_handler=MagicMock(),
tools=[],
@@ -271,7 +271,7 @@ def test_validate_tool_input_python_dict():
assert arguments == expected_arguments
def test_validate_tool_input_json5_unquoted_keys():
def test_validate_tool_input_json5_unquoted_keys() -> None:
tool_usage = ToolUsage(
tools_handler=MagicMock(),
tools=[],
@@ -288,7 +288,7 @@ def test_validate_tool_input_json5_unquoted_keys():
assert arguments == expected_arguments
def test_validate_tool_input_with_trailing_commas():
def test_validate_tool_input_with_trailing_commas() -> None:
tool_usage = ToolUsage(
tools_handler=MagicMock(),
tools=[],
@@ -305,7 +305,7 @@ def test_validate_tool_input_with_trailing_commas():
assert arguments == expected_arguments
def test_validate_tool_input_invalid_input():
def test_validate_tool_input_invalid_input() -> None:
# Create mock agent with proper string values
mock_agent = MagicMock()
mock_agent.key = "test_agent_key" # Must be a string
@@ -348,7 +348,7 @@ def test_validate_tool_input_invalid_input():
assert arguments == {}
def test_validate_tool_input_complex_structure():
def test_validate_tool_input_complex_structure() -> None:
tool_usage = ToolUsage(
tools_handler=MagicMock(),
tools=[],
@@ -384,7 +384,7 @@ def test_validate_tool_input_complex_structure():
assert arguments == expected_arguments
def test_validate_tool_input_code_content():
def test_validate_tool_input_code_content() -> None:
tool_usage = ToolUsage(
tools_handler=MagicMock(),
tools=[],
@@ -404,7 +404,7 @@ def test_validate_tool_input_code_content():
assert arguments == expected_arguments
def test_validate_tool_input_with_escaped_quotes():
def test_validate_tool_input_with_escaped_quotes() -> None:
tool_usage = ToolUsage(
tools_handler=MagicMock(),
tools=[],
@@ -421,7 +421,7 @@ def test_validate_tool_input_with_escaped_quotes():
assert arguments == expected_arguments
def test_validate_tool_input_large_json_content():
def test_validate_tool_input_large_json_content() -> None:
tool_usage = ToolUsage(
tools_handler=MagicMock(),
tools=[],
@@ -441,7 +441,7 @@ def test_validate_tool_input_large_json_content():
assert arguments == expected_arguments
def test_tool_selection_error_event_direct():
def test_tool_selection_error_event_direct() -> None:
"""Test tool selection error event emission directly from ToolUsage class."""
mock_agent = MagicMock()
mock_agent.key = "test_key"
@@ -473,10 +473,10 @@ def test_tool_selection_error_event_direct():
received_events = []
@crewai_event_bus.on(ToolSelectionErrorEvent)
def event_handler(source, event):
def event_handler(source, event) -> None:
received_events.append(event)
with pytest.raises(Exception) as exc_info:
with pytest.raises(Exception):
tool_usage._select_tool("Non Existent Tool")
assert len(received_events) == 1
event = received_events[0]
@@ -490,7 +490,7 @@ def test_tool_selection_error_event_direct():
assert "don't exist" in event.error
received_events.clear()
with pytest.raises(Exception) as exc_info:
with pytest.raises(Exception):
tool_usage._select_tool("")
assert len(received_events) == 1
@@ -504,7 +504,7 @@ def test_tool_selection_error_event_direct():
assert "forgot the Action name" in event.error
def test_tool_validate_input_error_event():
def test_tool_validate_input_error_event() -> None:
"""Test tool validation input error event emission from ToolUsage class."""
# Mock agent and required components
mock_agent = MagicMock()
@@ -558,12 +558,12 @@ def test_tool_validate_input_error_event():
received_events = []
@crewai_event_bus.on(ToolValidateInputErrorEvent)
def event_handler(source, event):
def event_handler(source, event) -> None:
received_events.append(event)
# Test invalid input
invalid_input = "invalid json {[}"
with pytest.raises(Exception) as exc_info:
with pytest.raises(Exception):
tool_usage._validate_tool_input(invalid_input)
# Verify event was emitted
@@ -576,7 +576,7 @@ def test_tool_validate_input_error_event():
assert "must be a valid dictionary" in event.error
def test_tool_usage_finished_event_with_result():
def test_tool_usage_finished_event_with_result() -> None:
"""Test that ToolUsageFinishedEvent is emitted with correct result attributes."""
# Create mock agent with proper string values
mock_agent = MagicMock()
@@ -618,7 +618,7 @@ def test_tool_usage_finished_event_with_result():
received_events = []
@crewai_event_bus.on(ToolUsageFinishedEvent)
def event_handler(source, event):
def event_handler(source, event) -> None:
received_events.append(event)
# Call on_tool_use_finished with test data
@@ -652,7 +652,7 @@ def test_tool_usage_finished_event_with_result():
assert event.type == "tool_usage_finished"
def test_tool_usage_finished_event_with_cached_result():
def test_tool_usage_finished_event_with_cached_result() -> None:
"""Test that ToolUsageFinishedEvent is emitted with correct result attributes when using cached result."""
# Create mock agent with proper string values
mock_agent = MagicMock()
@@ -694,7 +694,7 @@ def test_tool_usage_finished_event_with_cached_result():
received_events = []
@crewai_event_bus.on(ToolUsageFinishedEvent)
def event_handler(source, event):
def event_handler(source, event) -> None:
received_events.append(event)
# Call on_tool_use_finished with test data and from_cache=True

View File

@@ -25,15 +25,15 @@ class InternalCrewEvaluator:
return CrewEvaluator(crew, openai_model_name="gpt-4o-mini")
def test_setup_for_evaluating(self, crew_planner):
def test_setup_for_evaluating(self, crew_planner) -> None:
crew_planner._setup_for_evaluating()
assert crew_planner.crew.tasks[0].callback == crew_planner.evaluate
def test_set_iteration(self, crew_planner):
def test_set_iteration(self, crew_planner) -> None:
crew_planner.set_iteration(1)
assert crew_planner.iteration == 1
def test_evaluator_agent(self, crew_planner):
def test_evaluator_agent(self, crew_planner) -> None:
agent = crew_planner._evaluator_agent()
assert agent.role == "Task Execution Evaluator"
assert (
@@ -47,7 +47,7 @@ class InternalCrewEvaluator:
assert agent.verbose is False
assert agent.llm.model == "gpt-4o-mini"
def test_evaluation_task(self, crew_planner):
def test_evaluation_task(self, crew_planner) -> None:
evaluator_agent = Agent(
role="Evaluator Agent",
goal="Evaluate the performance of the agents in the crew",
@@ -60,11 +60,11 @@ class InternalCrewEvaluator:
)
task_output = "Task Output 1"
task = crew_planner._evaluation_task(
evaluator_agent, task_to_evaluate, task_output
evaluator_agent, task_to_evaluate, task_output,
)
assert task.description.startswith(
"Based on the task description and the expected output, compare and evaluate the performance of the agents in the crew based on the Task Output they have performed using score from 1 to 10 evaluating on completion, quality, and overall performance."
"Based on the task description and the expected output, compare and evaluate the performance of the agents in the crew based on the Task Output they have performed using score from 1 to 10 evaluating on completion, quality, and overall performance.",
)
assert task.agent == evaluator_agent
@@ -79,7 +79,7 @@ class InternalCrewEvaluator:
@mock.patch("crewai.utilities.evaluators.crew_evaluator_handler.Console")
@mock.patch("crewai.utilities.evaluators.crew_evaluator_handler.Table")
def test_print_crew_evaluation_result(self, table, console, crew_planner):
def test_print_crew_evaluation_result(self, table, console, crew_planner) -> None:
# Set up task scores and execution times
crew_planner.tasks_scores = {
1: [10, 9, 8],
@@ -97,10 +97,10 @@ class InternalCrewEvaluator:
]
crew_planner.crew.tasks = [
mock.Mock(
agent=crew_planner.crew.agents[0], processed_by_agents=["Agent 1"]
agent=crew_planner.crew.agents[0], processed_by_agents=["Agent 1"],
),
mock.Mock(
agent=crew_planner.crew.agents[1], processed_by_agents=["Agent 2"]
agent=crew_planner.crew.agents[1], processed_by_agents=["Agent 2"],
),
]
@@ -111,7 +111,7 @@ class InternalCrewEvaluator:
table.assert_has_calls(
[
mock.call(
title="Tasks Scores \n (1-10 Higher is better)", box=mock.ANY
title="Tasks Scores \n (1-10 Higher is better)", box=mock.ANY,
), # Title and styling
mock.call().add_column("Tasks/Crew/Agents", style="cyan"), # Columns
mock.call().add_column("Run 1", justify="center"),
@@ -125,15 +125,15 @@ class InternalCrewEvaluator:
# Add crew averages and execution times
mock.call().add_row("Crew", "9.00", "8.00", "8.5", ""),
mock.call().add_row("Execution Time (s)", "135", "155", "145", ""),
]
],
)
# Ensure the console prints the table
console.assert_has_calls([mock.call(), mock.call().print(table())])
def test_evaluate(self, crew_planner):
def test_evaluate(self, crew_planner) -> None:
task_output = TaskOutput(
description="Task 1", agent=str(crew_planner.crew.agents[0])
description="Task 1", agent=str(crew_planner.crew.agents[0]),
)
with mock.patch.object(Task, "execute_sync") as execute:

View File

@@ -8,7 +8,7 @@ from crewai.utilities.evaluators.task_evaluator import (
@patch("crewai.utilities.evaluators.task_evaluator.Converter")
def test_evaluate_training_data(converter_mock):
def test_evaluate_training_data(converter_mock) -> None:
training_data = {
"agent_id": {
"data1": {
@@ -21,7 +21,7 @@ def test_evaluate_training_data(converter_mock):
"human_feedback": "Human feedback 2",
"improved_output": "Improved output 2",
},
}
},
}
agent_id = "agent_id"
original_agent = MagicMock()
@@ -30,7 +30,7 @@ def test_evaluate_training_data(converter_mock):
suggestions=[
"The initial output was already good, having a detailed explanation. However, the improved output "
"gave similar information but in a more professional manner using better vocabulary. For future tasks, "
"try to implement more elaborate language and precise terminology from the beginning."
"try to implement more elaborate language and precise terminology from the beginning.",
],
quality=8.0,
final_summary="The agent responded well initially. However, the improved output showed that there is room "
@@ -39,7 +39,7 @@ def test_evaluate_training_data(converter_mock):
)
converter_mock.return_value.to_pydantic.return_value = function_return_value
result = TaskEvaluator(original_agent=original_agent).evaluate_training_data(
training_data, agent_id
training_data, agent_id,
)
assert result == function_return_value
@@ -61,5 +61,5 @@ def test_evaluate_training_data(converter_mock):
"following structure, with the following keys:\n{\n suggestions: List[str],\n quality: float,\n final_summary: str\n}",
),
mock.call().to_pydantic(),
]
],
)
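Most of the mechanical changes in the two hunks above are trailing commas added to the last argument of multi-line calls. A minimal before/after sketch of that fix, using only `unittest.mock` (the `evaluate` call here is a made-up placeholder, not the evaluator's real API):

```python
from unittest import mock

evaluator = mock.MagicMock()

# Before: no trailing comma, so adding an argument later rewrites this line
# and formatters may collapse the call back onto one line.
evaluator.evaluate(
    "training_data", "agent_id"
)

# After: the "magic trailing comma" keeps one argument per line and makes
# future additions a one-line diff.
evaluator.evaluate(
    "training_data",
    "agent_id",
)
```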

View File

@@ -1,3 +1,4 @@
from typing import NoReturn
from unittest.mock import Mock
from crewai.utilities.events.base_events import BaseEvent
@@ -8,11 +9,11 @@ class TestEvent(BaseEvent):
pass
def test_specific_event_handler():
def test_specific_event_handler() -> None:
mock_handler = Mock()
@crewai_event_bus.on(TestEvent)
def handler(source, event):
def handler(source, event) -> None:
mock_handler(source, event)
event = TestEvent(type="test_event")
@@ -21,11 +22,11 @@ def test_specific_event_handler():
mock_handler.assert_called_once_with("source_object", event)
def test_wildcard_event_handler():
def test_wildcard_event_handler() -> None:
mock_handler = Mock()
@crewai_event_bus.on(BaseEvent)
def handler(source, event):
def handler(source, event) -> None:
mock_handler(source, event)
event = TestEvent(type="test_event")
@@ -34,10 +35,11 @@ def test_wildcard_event_handler():
mock_handler.assert_called_once_with("source_object", event)
def test_event_bus_error_handling(capfd):
def test_event_bus_error_handling(capfd) -> None:
@crewai_event_bus.on(BaseEvent)
def broken_handler(source, event):
raise ValueError("Simulated handler failure")
def broken_handler(source, event) -> NoReturn:
msg = "Simulated handler failure"
raise ValueError(msg)
event = TestEvent(type="test_event")
crewai_event_bus.emit("source_object", event)
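Two fixes recur in this file: the always-raising handler gains a `typing.NoReturn` return annotation, and the exception message is bound to a `msg` variable before the `raise`. A standalone sketch of the same shape (attributing this to rules like EM101/ANN is my assumption; the commit does not say which linter produced the fixes):

```python
from typing import NoReturn


def broken_handler(source: object, event: object) -> NoReturn:
    # Binding the message first keeps the raise line short and avoids
    # repeating the literal when a test later matches on it.
    msg = "Simulated handler failure"
    raise ValueError(msg)
```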

View File

@@ -1,7 +1,4 @@
import unittest
from typing import Any, Dict, List, Union
import pytest
from crewai.utilities.chromadb import (
MAX_COLLECTION_LENGTH,
@@ -12,58 +9,58 @@ from crewai.utilities.chromadb import (
class TestChromadbUtils(unittest.TestCase):
def test_sanitize_collection_name_long_name(self):
def test_sanitize_collection_name_long_name(self) -> None:
"""Test sanitizing a very long collection name."""
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
sanitized = sanitize_collection_name(long_name)
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
assert len(sanitized) <= MAX_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_sanitize_collection_name_special_chars(self):
def test_sanitize_collection_name_special_chars(self) -> None:
"""Test sanitizing a name with special characters."""
special_chars = "Agent@123!#$%^&*()"
sanitized = sanitize_collection_name(special_chars)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_sanitize_collection_name_short_name(self):
def test_sanitize_collection_name_short_name(self) -> None:
"""Test sanitizing a very short name."""
short_name = "A"
sanitized = sanitize_collection_name(short_name)
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
def test_sanitize_collection_name_bad_ends(self):
def test_sanitize_collection_name_bad_ends(self) -> None:
"""Test sanitizing a name with non-alphanumeric start/end."""
bad_ends = "_Agent_"
sanitized = sanitize_collection_name(bad_ends)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
def test_sanitize_collection_name_none(self):
def test_sanitize_collection_name_none(self) -> None:
"""Test sanitizing a None value."""
sanitized = sanitize_collection_name(None)
self.assertEqual(sanitized, "default_collection")
assert sanitized == "default_collection"
def test_sanitize_collection_name_ipv4_pattern(self):
def test_sanitize_collection_name_ipv4_pattern(self) -> None:
"""Test sanitizing an IPv4 address."""
ipv4 = "192.168.1.1"
sanitized = sanitize_collection_name(ipv4)
self.assertTrue(sanitized.startswith("ip_"))
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
assert sanitized.startswith("ip_")
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_is_ipv4_pattern(self):
def test_is_ipv4_pattern(self) -> None:
"""Test IPv4 pattern detection."""
self.assertTrue(is_ipv4_pattern("192.168.1.1"))
self.assertFalse(is_ipv4_pattern("not.an.ip.address"))
assert is_ipv4_pattern("192.168.1.1")
assert not is_ipv4_pattern("not.an.ip.address")
def test_sanitize_collection_name_properties(self):
def test_sanitize_collection_name_properties(self) -> None:
"""Test that sanitized collection names always meet ChromaDB requirements."""
test_cases = [
"A" * 100, # Very long name
@@ -75,7 +72,7 @@ class TestChromadbUtils(unittest.TestCase):
]
for test_case in test_cases:
sanitized = sanitize_collection_name(test_case)
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert len(sanitized) <= MAX_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
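The class stays a `unittest.TestCase`, but every `self.assert*` call is rewritten to a bare `assert`, which pytest still collects and reports with full introspection. A small sketch of the equivalence (the `value` string is illustrative, not taken from `sanitize_collection_name`):

```python
import unittest


class AssertStyleExample(unittest.TestCase):
    def test_old_style(self) -> None:
        value = "agent_1"
        self.assertLessEqual(len(value), 63)
        self.assertTrue(value[0].isalnum())

    def test_new_style(self) -> None:
        value = "agent_1"
        # Same checks as above, written the way the hunk rewrites them.
        assert len(value) <= 63
        assert value[0].isalnum()
```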

View File

@@ -1,6 +1,4 @@
import json
import os
from typing import Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch
import pytest
@@ -73,7 +71,7 @@ def mock_agent():
# Tests for convert_to_model
def test_convert_to_model_with_valid_json():
def test_convert_to_model_with_valid_json() -> None:
result = '{"name": "John", "age": 30}'
output = convert_to_model(result, SimpleModel, None, None)
assert isinstance(output, SimpleModel)
@@ -81,7 +79,7 @@ def test_convert_to_model_with_valid_json():
assert output.age == 30
def test_convert_to_model_with_invalid_json():
def test_convert_to_model_with_invalid_json() -> None:
result = '{"name": "John", "age": "thirty"}'
with patch("crewai.utilities.converter.handle_partial_json") as mock_handle:
mock_handle.return_value = "Fallback result"
@@ -89,13 +87,13 @@ def test_convert_to_model_with_invalid_json():
assert output == "Fallback result"
def test_convert_to_model_with_no_model():
def test_convert_to_model_with_no_model() -> None:
result = "Plain text"
output = convert_to_model(result, None, None, None)
assert output == "Plain text"
def test_convert_to_model_with_special_characters():
def test_convert_to_model_with_special_characters() -> None:
json_string_test = """
{
"responses": [
@@ -114,15 +112,15 @@ def test_convert_to_model_with_special_characters():
)
def test_convert_to_model_with_escaped_special_characters():
def test_convert_to_model_with_escaped_special_characters() -> None:
json_string_test = json.dumps(
{
"responses": [
{
"previous_message_content": "Hi Tom,\r\n\r\nNiamh has chosen the Mika phonics on"
}
]
}
"previous_message_content": "Hi Tom,\r\n\r\nNiamh has chosen the Mika phonics on",
},
],
},
)
output = convert_to_model(json_string_test, EmailResponses, None, None)
assert isinstance(output, EmailResponses)
@@ -133,7 +131,7 @@ def test_convert_to_model_with_escaped_special_characters():
)
def test_convert_to_model_with_multiple_special_characters():
def test_convert_to_model_with_multiple_special_characters() -> None:
json_string_test = """
{
"responses": [
@@ -153,7 +151,7 @@ def test_convert_to_model_with_multiple_special_characters():
# Tests for validate_model
def test_validate_model_pydantic_output():
def test_validate_model_pydantic_output() -> None:
result = '{"name": "Alice", "age": 25}'
output = validate_model(result, SimpleModel, False)
assert isinstance(output, SimpleModel)
@@ -161,7 +159,7 @@ def test_validate_model_pydantic_output():
assert output.age == 25
def test_validate_model_json_output():
def test_validate_model_json_output() -> None:
result = '{"name": "Bob", "age": 40}'
output = validate_model(result, SimpleModel, True)
assert isinstance(output, dict)
@@ -169,7 +167,7 @@ def test_validate_model_json_output():
# Tests for handle_partial_json
def test_handle_partial_json_with_valid_partial():
def test_handle_partial_json_with_valid_partial() -> None:
result = 'Some text {"name": "Charlie", "age": 35} more text'
output = handle_partial_json(result, SimpleModel, False, None)
assert isinstance(output, SimpleModel)
@@ -177,7 +175,7 @@ def test_handle_partial_json_with_valid_partial():
assert output.age == 35
def test_handle_partial_json_with_invalid_partial(mock_agent):
def test_handle_partial_json_with_invalid_partial(mock_agent) -> None:
result = "No valid JSON here"
with patch("crewai.utilities.converter.convert_with_instructions") as mock_convert:
mock_convert.return_value = "Converted result"
@@ -189,8 +187,8 @@ def test_handle_partial_json_with_invalid_partial(mock_agent):
@patch("crewai.utilities.converter.create_converter")
@patch("crewai.utilities.converter.get_conversion_instructions")
def test_convert_with_instructions_success(
mock_get_instructions, mock_create_converter, mock_agent
):
mock_get_instructions, mock_create_converter, mock_agent,
) -> None:
mock_get_instructions.return_value = "Instructions"
mock_converter = Mock()
mock_converter.to_pydantic.return_value = SimpleModel(name="David", age=50)
@@ -207,8 +205,8 @@ def test_convert_with_instructions_success(
@patch("crewai.utilities.converter.create_converter")
@patch("crewai.utilities.converter.get_conversion_instructions")
def test_convert_with_instructions_failure(
mock_get_instructions, mock_create_converter, mock_agent
):
mock_get_instructions, mock_create_converter, mock_agent,
) -> None:
mock_get_instructions.return_value = "Instructions"
mock_converter = Mock()
mock_converter.to_pydantic.return_value = ConverterError("Conversion failed")
@@ -222,7 +220,7 @@ def test_convert_with_instructions_failure(
# Tests for get_conversion_instructions
def test_get_conversion_instructions_gpt():
def test_get_conversion_instructions_gpt() -> None:
llm = LLM(model="gpt-4o-mini")
with patch.object(LLM, "supports_function_calling") as supports_function_calling:
supports_function_calling.return_value = True
@@ -237,7 +235,7 @@ def test_get_conversion_instructions_gpt():
assert instructions == expected_instructions
def test_get_conversion_instructions_non_gpt():
def test_get_conversion_instructions_non_gpt() -> None:
llm = LLM(model="ollama/llama3.1", base_url="http://localhost:11434")
with patch.object(LLM, "supports_function_calling", return_value=False):
instructions = get_conversion_instructions(SimpleModel, llm)
@@ -246,17 +244,17 @@ def test_get_conversion_instructions_non_gpt():
# Tests for is_gpt
def test_supports_function_calling_true():
def test_supports_function_calling_true() -> None:
llm = LLM(model="gpt-4o")
assert llm.supports_function_calling() is True
def test_supports_function_calling_false():
def test_supports_function_calling_false() -> None:
llm = LLM(model="non-existent-model")
assert llm.supports_function_calling() is False
def test_create_converter_with_mock_agent():
def test_create_converter_with_mock_agent() -> None:
mock_agent = MagicMock()
mock_agent.get_output_converter.return_value = MagicMock(spec=Converter)
@@ -272,7 +270,7 @@ def test_create_converter_with_mock_agent():
mock_agent.get_output_converter.assert_called_once()
def test_create_converter_with_custom_converter():
def test_create_converter_with_custom_converter() -> None:
converter = create_converter(
converter_cls=CustomConverter,
llm=LLM(model="gpt-4o-mini"),
@@ -284,22 +282,22 @@ def test_create_converter_with_custom_converter():
assert isinstance(converter, CustomConverter)
def test_create_converter_fails_without_agent_or_converter_cls():
def test_create_converter_fails_without_agent_or_converter_cls() -> None:
with pytest.raises(
ValueError, match="Either agent or converter_cls must be provided"
ValueError, match="Either agent or converter_cls must be provided",
):
create_converter(
llm=Mock(), text="Sample", model=SimpleModel, instructions="Convert"
llm=Mock(), text="Sample", model=SimpleModel, instructions="Convert",
)
def test_generate_model_description_simple_model():
def test_generate_model_description_simple_model() -> None:
description = generate_model_description(SimpleModel)
expected_description = '{\n "name": str,\n "age": int\n}'
assert description == expected_description
def test_generate_model_description_nested_model():
def test_generate_model_description_nested_model() -> None:
description = generate_model_description(NestedModel)
expected_description = (
'{\n "id": int,\n "data": {\n "name": str,\n "age": int\n}\n}'
@@ -307,9 +305,9 @@ def test_generate_model_description_nested_model():
assert description == expected_description
def test_generate_model_description_optional_field():
def test_generate_model_description_optional_field() -> None:
class ModelWithOptionalField(BaseModel):
name: Optional[str]
name: str | None
age: int
description = generate_model_description(ModelWithOptionalField)
@@ -317,18 +315,18 @@ def test_generate_model_description_optional_field():
assert description == expected_description
def test_generate_model_description_list_field():
def test_generate_model_description_list_field() -> None:
class ModelWithListField(BaseModel):
items: List[int]
items: list[int]
description = generate_model_description(ModelWithListField)
expected_description = '{\n "items": List[int]\n}'
assert description == expected_description
def test_generate_model_description_dict_field():
def test_generate_model_description_dict_field() -> None:
class ModelWithDictField(BaseModel):
attributes: Dict[str, int]
attributes: dict[str, int]
description = generate_model_description(ModelWithDictField)
expected_description = '{\n "attributes": Dict[str, int]\n}'
@@ -336,7 +334,7 @@ def test_generate_model_description_dict_field():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_convert_with_instructions():
def test_convert_with_instructions() -> None:
llm = LLM(model="gpt-4o-mini")
sample_text = "Name: Alice, Age: 30"
@@ -358,7 +356,7 @@ def test_convert_with_instructions():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_converter_with_llama3_2_model():
def test_converter_with_llama3_2_model() -> None:
llm = LLM(model="ollama/llama3.2:3b", base_url="http://localhost:11434")
sample_text = "Name: Alice Llama, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
@@ -375,7 +373,7 @@ def test_converter_with_llama3_2_model():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_converter_with_llama3_1_model():
def test_converter_with_llama3_1_model() -> None:
llm = LLM(model="ollama/llama3.1", base_url="http://localhost:11434")
sample_text = "Name: Alice Llama, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
@@ -392,7 +390,7 @@ def test_converter_with_llama3_1_model():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_converter_with_nested_model():
def test_converter_with_nested_model() -> None:
llm = LLM(model="gpt-4o-mini")
sample_text = "Name: John Doe\nAge: 30\nAddress: 123 Main St, Anytown, 12345"
@@ -416,7 +414,7 @@ def test_converter_with_nested_model():
# Tests for error handling
def test_converter_error_handling():
def test_converter_error_handling() -> None:
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.return_value = "Invalid JSON"
@@ -431,13 +429,13 @@ def test_converter_error_handling():
)
with pytest.raises(ConverterError) as exc_info:
output = converter.to_pydantic()
converter.to_pydantic()
assert "Failed to convert text into a Pydantic model" in str(exc_info.value)
# Tests for retry logic
def test_converter_retry_logic():
def test_converter_retry_logic() -> None:
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.side_effect = [
@@ -465,10 +463,10 @@ def test_converter_retry_logic():
# Tests for optional fields
def test_converter_with_optional_fields():
def test_converter_with_optional_fields() -> None:
class OptionalModel(BaseModel):
name: str
age: Optional[int]
age: int | None
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
@@ -492,9 +490,9 @@ def test_converter_with_optional_fields():
# Tests for list fields
def test_converter_with_list_field():
def test_converter_with_list_field() -> None:
class ListModel(BaseModel):
items: List[int]
items: list[int]
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
@@ -519,7 +517,7 @@ def test_converter_with_list_field():
from enum import Enum
def test_converter_with_enum():
def test_converter_with_enum() -> None:
class Color(Enum):
RED = "red"
GREEN = "green"
@@ -550,7 +548,7 @@ def test_converter_with_enum():
# Tests for ambiguous input
def test_converter_with_ambiguous_input():
def test_converter_with_ambiguous_input() -> None:
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.return_value = '{"name": "Charlie", "age": "Not an age"}'
@@ -565,13 +563,13 @@ def test_converter_with_ambiguous_input():
)
with pytest.raises(ConverterError) as exc_info:
output = converter.to_pydantic()
converter.to_pydantic()
assert "failed to convert text into a pydantic model" in str(exc_info.value).lower()
# Tests for function calling support
def test_converter_with_function_calling():
def test_converter_with_function_calling() -> None:
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = True
@@ -594,7 +592,7 @@ def test_converter_with_function_calling():
instructor.to_pydantic.assert_called_once()
def test_generate_model_description_union_field():
def test_generate_model_description_union_field() -> None:
class UnionModel(BaseModel):
field: int | str | None
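This file also shows the typing upgrade applied throughout the tests: `List`, `Dict`, `Optional`, and `Union` from `typing` become builtin generics and `|` unions (PEP 585/604), which is why the `typing` imports disappear at the top of the hunk. A minimal sketch with a throwaway model (not one of the fixtures above):

```python
from pydantic import BaseModel


class ModernFields(BaseModel):
    # Before: items: List[int], attributes: Dict[str, int],
    #         age: Optional[int], union_field: Union[int, str, None]
    items: list[int]
    attributes: dict[str, int]
    age: int | None = None
    union_field: int | str | None = None


print(ModernFields(items=[1, 2], attributes={"a": 1}))
```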

View File

@@ -1,5 +1,5 @@
import os
from datetime import datetime
from typing import NoReturn
from unittest.mock import Mock, patch
import pytest
@@ -38,7 +38,6 @@ from crewai.utilities.events.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMCallType,
LLMStreamChunkEvent,
)
from crewai.utilities.events.task_events import (
@@ -74,21 +73,21 @@ event_listener = EventListener()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_start_kickoff_event():
def test_crew_emits_start_kickoff_event() -> None:
received_events = []
mock_span = Mock()
@crewai_event_bus.on(CrewKickoffStartedEvent)
def handle_crew_start(source, event):
def handle_crew_start(source, event) -> None:
received_events.append(event)
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
with (
patch.object(
event_listener._telemetry, "crew_execution_span", return_value=mock_span
event_listener._telemetry, "crew_execution_span", return_value=mock_span,
) as mock_crew_execution_span,
patch.object(
event_listener._telemetry, "end_crew", return_value=mock_span
event_listener._telemetry, "end_crew", return_value=mock_span,
) as mock_crew_ended,
):
crew.kickoff()
@@ -102,11 +101,11 @@ def test_crew_emits_start_kickoff_event():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_end_kickoff_event():
def test_crew_emits_end_kickoff_event() -> None:
received_events = []
@crewai_event_bus.on(CrewKickoffCompletedEvent)
def handle_crew_end(source, event):
def handle_crew_end(source, event) -> None:
received_events.append(event)
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
@@ -120,22 +119,22 @@ def test_crew_emits_end_kickoff_event():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_test_kickoff_type_event():
def test_crew_emits_test_kickoff_type_event() -> None:
received_events = []
mock_span = Mock()
@crewai_event_bus.on(CrewTestStartedEvent)
def handle_crew_end(source, event):
def handle_crew_end(source, event) -> None:
received_events.append(event)
@crewai_event_bus.on(CrewTestCompletedEvent)
def handle_crew_test_end(source, event):
def handle_crew_test_end(source, event) -> None:
received_events.append(event)
eval_llm = LLM(model="gpt-4o-mini")
with (
patch.object(
event_listener._telemetry, "test_execution_span", return_value=mock_span
event_listener._telemetry, "test_execution_span", return_value=mock_span,
) as mock_crew_execution_span,
):
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
@@ -159,13 +158,13 @@ def test_crew_emits_test_kickoff_type_event():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_kickoff_failed_event():
def test_crew_emits_kickoff_failed_event() -> None:
received_events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(CrewKickoffFailedEvent)
def handle_crew_failed(source, event):
def handle_crew_failed(source, event) -> None:
received_events.append(event)
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
@@ -184,11 +183,11 @@ def test_crew_emits_kickoff_failed_event():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_start_task_event():
def test_crew_emits_start_task_event() -> None:
received_events = []
@crewai_event_bus.on(TaskStartedEvent)
def handle_task_start(source, event):
def handle_task_start(source, event) -> None:
received_events.append(event)
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
@@ -201,21 +200,21 @@ def test_crew_emits_start_task_event():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_end_task_event():
def test_crew_emits_end_task_event() -> None:
received_events = []
@crewai_event_bus.on(TaskCompletedEvent)
def handle_task_end(source, event):
def handle_task_end(source, event) -> None:
received_events.append(event)
mock_span = Mock()
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
with (
patch.object(
event_listener._telemetry, "task_started", return_value=mock_span
event_listener._telemetry, "task_started", return_value=mock_span,
) as mock_task_started,
patch.object(
event_listener._telemetry, "task_ended", return_value=mock_span
event_listener._telemetry, "task_ended", return_value=mock_span,
) as mock_task_ended,
):
crew.kickoff()
@@ -229,12 +228,12 @@ def test_crew_emits_end_task_event():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_task_emits_failed_event_on_execution_error():
def test_task_emits_failed_event_on_execution_error() -> None:
received_events = []
received_sources = []
@crewai_event_bus.on(TaskFailedEvent)
def handle_task_failed(source, event):
def handle_task_failed(source, event) -> None:
received_events.append(event)
received_sources.append(source)
@@ -266,15 +265,15 @@ def test_task_emits_failed_event_on_execution_error():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_emits_execution_started_and_completed_events():
def test_agent_emits_execution_started_and_completed_events() -> None:
received_events = []
@crewai_event_bus.on(AgentExecutionStartedEvent)
def handle_agent_start(source, event):
def handle_agent_start(source, event) -> None:
received_events.append(event)
@crewai_event_bus.on(AgentExecutionCompletedEvent)
def handle_agent_completed(source, event):
def handle_agent_completed(source, event) -> None:
received_events.append(event)
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
@@ -295,21 +294,21 @@ def test_agent_emits_execution_started_and_completed_events():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_emits_execution_error_event():
def test_agent_emits_execution_error_event() -> None:
received_events = []
@crewai_event_bus.on(AgentExecutionErrorEvent)
def handle_agent_start(source, event):
def handle_agent_start(source, event) -> None:
received_events.append(event)
error_message = "Error happening while sending prompt to model."
base_agent.max_retry_limit = 0
with patch.object(
CrewAgentExecutor, "invoke", wraps=base_agent.agent_executor.invoke
CrewAgentExecutor, "invoke", wraps=base_agent.agent_executor.invoke,
) as invoke_mock:
invoke_mock.side_effect = Exception(error_message)
with pytest.raises(Exception) as e:
with pytest.raises(Exception):
base_agent.execute_task(
task=base_task,
)
@@ -325,7 +324,7 @@ def test_agent_emits_execution_error_event():
class SayHiTool(BaseTool):
name: str = Field(default="say_hi", description="The name of the tool")
description: str = Field(
default="Say hi", description="The description of the tool"
default="Say hi", description="The description of the tool",
)
def _run(self) -> str:
@@ -333,11 +332,11 @@ class SayHiTool(BaseTool):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_tools_emits_finished_events():
def test_tools_emits_finished_events() -> None:
received_events = []
@crewai_event_bus.on(ToolUsageFinishedEvent)
def handle_tool_end(source, event):
def handle_tool_end(source, event) -> None:
received_events.append(event)
agent = Agent(
@@ -364,16 +363,16 @@ def test_tools_emits_finished_events():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_tools_emits_error_events():
def test_tools_emits_error_events() -> None:
received_events = []
@crewai_event_bus.on(ToolUsageErrorEvent)
def handle_tool_end(source, event):
def handle_tool_end(source, event) -> None:
received_events.append(event)
class ErrorTool(BaseTool):
name: str = Field(
default="error_tool", description="A tool that raises an error"
default="error_tool", description="A tool that raises an error",
)
description: str = Field(
default="This tool always raises an error",
@@ -381,7 +380,8 @@ def test_tools_emits_error_events():
)
def _run(self) -> str:
raise Exception("Simulated tool error")
msg = "Simulated tool error"
raise Exception(msg)
agent = Agent(
role="base_agent",
@@ -410,22 +410,22 @@ def test_tools_emits_error_events():
assert isinstance(received_events[0].timestamp, datetime)
def test_flow_emits_start_event():
def test_flow_emits_start_event() -> None:
received_events = []
mock_span = Mock()
@crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event):
def handle_flow_start(source, event) -> None:
received_events.append(event)
class TestFlow(Flow[dict]):
@start()
def begin(self):
def begin(self) -> str:
return "started"
with (
patch.object(
event_listener._telemetry, "flow_execution_span", return_value=mock_span
event_listener._telemetry, "flow_execution_span", return_value=mock_span,
) as mock_flow_execution_span,
):
flow = TestFlow()
@@ -437,18 +437,18 @@ def test_flow_emits_start_event():
assert received_events[0].type == "flow_started"
def test_flow_emits_finish_event():
def test_flow_emits_finish_event() -> None:
received_events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_finish(source, event):
def handle_flow_finish(source, event) -> None:
received_events.append(event)
class TestFlow(Flow[dict]):
@start()
def begin(self):
def begin(self) -> str:
return "completed"
flow = TestFlow()
@@ -461,23 +461,22 @@ def test_flow_emits_finish_event():
assert result == "completed"
def test_flow_emits_method_execution_started_event():
def test_flow_emits_method_execution_started_event() -> None:
received_events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_start(source, event):
print("event in method name", event.method_name)
def handle_method_start(source, event) -> None:
received_events.append(event)
class TestFlow(Flow[dict]):
@start()
def begin(self):
def begin(self) -> str:
return "started"
@listen("begin")
def second_method(self):
def second_method(self) -> str:
return "executed"
flow = TestFlow()
@@ -495,10 +494,10 @@ def test_flow_emits_method_execution_started_event():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_register_handler_adds_new_handler():
def test_register_handler_adds_new_handler() -> None:
received_events = []
def custom_handler(source, event):
def custom_handler(source, event) -> None:
received_events.append(event)
with crewai_event_bus.scoped_handlers():
@@ -513,14 +512,14 @@ def test_register_handler_adds_new_handler():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_multiple_handlers_for_same_event():
def test_multiple_handlers_for_same_event() -> None:
received_events_1 = []
received_events_2 = []
def handler_1(source, event):
def handler_1(source, event) -> None:
received_events_1.append(event)
def handler_2(source, event):
def handler_2(source, event) -> None:
received_events_2.append(event)
with crewai_event_bus.scoped_handlers():
@@ -536,22 +535,22 @@ def test_multiple_handlers_for_same_event():
assert received_events_2[0].type == "crew_kickoff_started"
def test_flow_emits_created_event():
def test_flow_emits_created_event() -> None:
received_events = []
mock_span = Mock()
@crewai_event_bus.on(FlowCreatedEvent)
def handle_flow_created(source, event):
def handle_flow_created(source, event) -> None:
received_events.append(event)
class TestFlow(Flow[dict]):
@start()
def begin(self):
def begin(self) -> str:
return "started"
with (
patch.object(
event_listener._telemetry, "flow_creation_span", return_value=mock_span
event_listener._telemetry, "flow_creation_span", return_value=mock_span,
) as mock_flow_creation_span,
):
flow = TestFlow()
@@ -564,17 +563,17 @@ def test_flow_emits_created_event():
assert received_events[0].type == "flow_created"
def test_flow_emits_method_execution_failed_event():
def test_flow_emits_method_execution_failed_event() -> None:
received_events = []
error = Exception("Simulated method failure")
@crewai_event_bus.on(MethodExecutionFailedEvent)
def handle_method_failed(source, event):
def handle_method_failed(source, event) -> None:
received_events.append(event)
class TestFlow(Flow[dict]):
@start()
def begin(self):
def begin(self) -> NoReturn:
raise error
flow = TestFlow()
@@ -589,15 +588,15 @@ def test_flow_emits_method_execution_failed_event():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_emits_call_started_event():
def test_llm_emits_call_started_event() -> None:
received_events = []
@crewai_event_bus.on(LLMCallStartedEvent)
def handle_llm_call_started(source, event):
def handle_llm_call_started(source, event) -> None:
received_events.append(event)
@crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_call_completed(source, event):
def handle_llm_call_completed(source, event) -> None:
received_events.append(event)
llm = LLM(model="gpt-4o-mini")
@@ -609,11 +608,11 @@ def test_llm_emits_call_started_event():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_emits_call_failed_event():
def test_llm_emits_call_failed_event() -> None:
received_events = []
@crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_call_failed(source, event):
def handle_llm_call_failed(source, event) -> None:
received_events.append(event)
error_message = "Simulated LLM call failure"
@@ -629,14 +628,14 @@ def test_llm_emits_call_failed_event():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_emits_stream_chunk_events():
def test_llm_emits_stream_chunk_events() -> None:
"""Test that LLM emits stream chunk events when streaming is enabled."""
received_chunks = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
def handle_stream_chunk(source, event) -> None:
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled
@@ -653,14 +652,14 @@ def test_llm_emits_stream_chunk_events():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_no_stream_chunks_when_streaming_disabled():
def test_llm_no_stream_chunks_when_streaming_disabled() -> None:
"""Test that LLM doesn't emit stream chunk events when streaming is disabled."""
received_chunks = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
def handle_stream_chunk(source, event) -> None:
received_chunks.append(event.chunk)
# Create an LLM with streaming disabled
@@ -673,11 +672,12 @@ def test_llm_no_stream_chunks_when_streaming_disabled():
assert len(received_chunks) == 0
# Verify we got a response
assert response and isinstance(response, str)
assert response
assert isinstance(response, str)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_streaming_fallback_to_non_streaming():
def test_streaming_fallback_to_non_streaming() -> None:
"""Test that streaming falls back to non-streaming when there's an error."""
received_chunks = []
fallback_called = False
@@ -685,7 +685,7 @@ def test_streaming_fallback_to_non_streaming():
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
def handle_stream_chunk(source, event) -> None:
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled
@@ -695,7 +695,7 @@ def test_streaming_fallback_to_non_streaming():
original_call = llm.call
# Create a mock call method that handles the streaming error
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
def mock_call(messages, tools=None, callbacks=None, available_functions=None) -> str:
nonlocal fallback_called
# Emit a couple of chunks to simulate partial streaming
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
@@ -731,14 +731,14 @@ def test_streaming_fallback_to_non_streaming():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_streaming_empty_response_handling():
def test_streaming_empty_response_handling() -> None:
"""Test that streaming handles empty responses correctly."""
received_chunks = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
def handle_stream_chunk(source, event) -> None:
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled
@@ -748,7 +748,7 @@ def test_streaming_empty_response_handling():
original_call = llm.call
# Create a mock call method that simulates empty chunks
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
def mock_call(messages, tools=None, callbacks=None, available_functions=None) -> str:
# Emit a few empty chunks
for _ in range(3):
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
@@ -768,7 +768,8 @@ def test_streaming_empty_response_handling():
assert all(chunk == "" for chunk in received_chunks)
# Verify the response is the default message for empty responses
assert "I apologize" in response and "couldn't generate" in response
assert "I apologize" in response
assert "couldn't generate" in response
finally:
# Restore the original method
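Compound assertions joined with `and` are split into separate statements in this file (e.g. the response checks above), so a failure report points at exactly the clause that broke. A tiny sketch of the rewrite:

```python
def check_response(response: str) -> None:
    # Before: assert response and isinstance(response, str)
    # After: one condition per assert statement.
    assert response
    assert isinstance(response, str)


check_response("Test chunk 1")
```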

View File

@@ -7,16 +7,16 @@ from crewai.utilities.file_handler import PickleHandler
class TestPickleHandler(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.file_name = "test_data.pkl"
self.file_path = os.path.join(os.getcwd(), self.file_name)
self.handler = PickleHandler(self.file_name)
def tearDown(self):
def tearDown(self) -> None:
if os.path.exists(self.file_path):
os.remove(self.file_path)
def test_initialize_file(self):
def test_initialize_file(self) -> None:
assert os.path.exists(self.file_path) is False
self.handler.initialize_file()
@@ -24,17 +24,17 @@ class TestPickleHandler(unittest.TestCase):
assert os.path.exists(self.file_path) is True
assert os.path.getsize(self.file_path) >= 0
def test_save_and_load(self):
def test_save_and_load(self) -> None:
data = {"key": "value"}
self.handler.save(data)
loaded_data = self.handler.load()
assert loaded_data == data
def test_load_empty_file(self):
def test_load_empty_file(self) -> None:
loaded_data = self.handler.load()
assert loaded_data == {}
def test_load_corrupted_file(self):
def test_load_corrupted_file(self) -> None:
with open(self.file_path, "wb") as file:
file.write(b"corrupted data")
@@ -42,4 +42,4 @@ class TestPickleHandler(unittest.TestCase):
self.handler.load()
assert str(exc.value) == "pickle data was truncated"
assert "<class '_pickle.UnpicklingError'>" == str(exc.type)
assert str(exc.type) == "<class '_pickle.UnpicklingError'>"

View File

@@ -3,38 +3,38 @@ import pytest
from crewai.utilities.i18n import I18N
def test_load_prompts():
def test_load_prompts() -> None:
i18n = I18N()
i18n.load_prompts()
assert i18n._prompts is not None
def test_slice():
def test_slice() -> None:
i18n = I18N()
i18n.load_prompts()
assert isinstance(i18n.slice("role_playing"), str)
def test_tools():
def test_tools() -> None:
i18n = I18N()
i18n.load_prompts()
assert isinstance(i18n.tools("ask_question"), str)
def test_retrieve():
def test_retrieve() -> None:
i18n = I18N()
i18n.load_prompts()
assert isinstance(i18n.retrieve("slices", "role_playing"), str)
def test_retrieve_not_found():
def test_retrieve_not_found() -> None:
i18n = I18N()
i18n.load_prompts()
with pytest.raises(Exception):
i18n.retrieve("nonexistent_kind", "nonexistent_key")
def test_prompt_file():
def test_prompt_file() -> None:
import os
path = os.path.join(os.path.dirname(__file__), "prompts.json")

View File

@@ -1,5 +1,4 @@
"""
Tests for verifying the integration of knowledge sources in the planning process.
"""Tests for verifying the integration of knowledge sources in the planning process.
This module ensures that agent knowledge is properly included during task planning.
"""
@@ -15,11 +14,12 @@ from crewai.utilities.planning_handler import CrewPlanner
@pytest.fixture
def mock_knowledge_source():
"""
Create a mock knowledge source with test content.
"""Create a mock knowledge source with test content.
Returns:
StringKnowledgeSource:
A knowledge source containing AI-related test content
A knowledge source containing AI-related test content.
"""
content = """
Important context about AI:
@@ -29,13 +29,13 @@ def mock_knowledge_source():
"""
return StringKnowledgeSource(content=content)
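The module and fixture docstrings are reflowed so the summary sits on the same line as the opening quotes and ends with a period. A side-by-side sketch of the two layouts, with the text copied from the fixture above:

```python
def mock_knowledge_source_before():
    """
    Create a mock knowledge source with test content.
    """


def mock_knowledge_source_after():
    """Create a mock knowledge source with test content.

    Returns:
        StringKnowledgeSource:
            A knowledge source containing AI-related test content.
    """
```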
@patch('crewai.knowledge.storage.knowledge_storage.chromadb')
def test_knowledge_included_in_planning(mock_chroma):
@patch("crewai.knowledge.storage.knowledge_storage.chromadb")
def test_knowledge_included_in_planning(mock_chroma) -> None:
"""Test that verifies knowledge sources are properly included in planning."""
# Mock ChromaDB collection
mock_collection = mock_chroma.return_value.get_or_create_collection.return_value
mock_collection.add.return_value = None
# Create an agent with knowledge
agent = Agent(
role="AI Researcher",
@@ -43,16 +43,16 @@ def test_knowledge_included_in_planning(mock_chroma):
backstory="Expert in artificial intelligence",
knowledge_sources=[
StringKnowledgeSource(
content="AI systems require careful training and validation."
)
]
content="AI systems require careful training and validation.",
),
],
)
# Create a task for the agent
task = Task(
description="Explain the basics of AI systems",
expected_output="A clear explanation of AI fundamentals",
agent=agent
agent=agent,
)
# Create a crew planner

View File

@@ -8,25 +8,25 @@ from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm
def test_create_llm_with_llm_instance():
def test_create_llm_with_llm_instance() -> None:
existing_llm = LLM(model="gpt-4o")
llm = create_llm(llm_value=existing_llm)
assert llm is existing_llm
def test_create_llm_with_valid_model_string():
def test_create_llm_with_valid_model_string() -> None:
llm = create_llm(llm_value="gpt-4o")
assert isinstance(llm, LLM)
assert llm.model == "gpt-4o"
def test_create_llm_with_invalid_model_string():
def test_create_llm_with_invalid_model_string() -> None:
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
llm = create_llm(llm_value="invalid-model")
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
def test_create_llm_with_unknown_object_missing_attributes():
def test_create_llm_with_unknown_object_missing_attributes() -> None:
class UnknownObject:
pass
@@ -38,7 +38,7 @@ def test_create_llm_with_unknown_object_missing_attributes():
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
def test_create_llm_with_none_uses_default_model():
def test_create_llm_with_none_uses_default_model() -> None:
with patch.dict(os.environ, {}, clear=True):
with patch("crewai.cli.constants.DEFAULT_LLM_MODEL", "gpt-4o"):
llm = create_llm(llm_value=None)
@@ -46,7 +46,7 @@ def test_create_llm_with_none_uses_default_model():
assert llm.model == "gpt-4o-mini"
def test_create_llm_with_unknown_object():
def test_create_llm_with_unknown_object() -> None:
class UnknownObject:
model_name = "gpt-4o"
temperature = 0.7
@@ -60,7 +60,7 @@ def test_create_llm_with_unknown_object():
assert llm.max_tokens == 1500
def test_create_llm_from_env_with_unaccepted_attributes():
def test_create_llm_from_env_with_unaccepted_attributes() -> None:
with patch.dict(
os.environ,
{
@@ -78,7 +78,7 @@ def test_create_llm_from_env_with_unaccepted_attributes():
assert not hasattr(llm, "AWS_REGION_NAME")
def test_create_llm_with_partial_attributes():
def test_create_llm_with_partial_attributes() -> None:
class PartialAttributes:
model_name = "gpt-4o"
# temperature is missing
@@ -90,7 +90,7 @@ def test_create_llm_with_partial_attributes():
assert llm.temperature is None # Should handle missing attributes gracefully
def test_create_llm_with_invalid_type():
def test_create_llm_with_invalid_type() -> None:
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
llm = create_llm(llm_value=42)
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
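Nearly every change in this file is the `-> None` return annotation added to test functions, which lets type checkers and annotation lint rules cover the test suite without changing behaviour. A sketch of the rewrite on one of the simpler tests above, assuming `crewai` is installed and `create_llm` returns an `LLM` as the assertions state:

```python
from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm


# Before: no return annotation, so the body is untyped for mypy.
def test_create_llm_with_valid_model_string():
    llm = create_llm(llm_value="gpt-4o")
    assert isinstance(llm, LLM)


# After: tests return nothing, so the fixer annotates them as -> None.
def test_create_llm_with_valid_model_string_annotated() -> None:
    llm = create_llm(llm_value="gpt-4o")
    assert isinstance(llm, LLM)
```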

View File

@@ -1,8 +1,6 @@
from typing import Optional
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from pydantic import BaseModel
from crewai.agent import Agent
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
@@ -45,12 +43,12 @@ class InternalCrewPlanner:
description="Task 1",
expected_output="Output 1",
agent=Agent(role="Agent 1", goal="Goal 1", backstory="Backstory 1"),
)
),
]
planning_agent_llm = "gpt-3.5-turbo"
return CrewPlanner(tasks, planning_agent_llm)
def test_handle_crew_planning(self, crew_planner):
def test_handle_crew_planning(self, crew_planner) -> None:
list_of_plans_per_task = [
PlanPerTask(task="Task1", plan="Plan 1"),
PlanPerTask(task="Task2", plan="Plan 2"),
@@ -61,7 +59,7 @@ class InternalCrewPlanner:
description="Description",
agent="agent",
pydantic=PlannerTaskPydanticOutput(
list_of_plans_per_task=list_of_plans_per_task
list_of_plans_per_task=list_of_plans_per_task,
),
)
result = crew_planner._handle_crew_planning()
@@ -70,12 +68,12 @@ class InternalCrewPlanner:
assert len(result.list_of_plans_per_task) == len(crew_planner.tasks)
execute.assert_called_once()
def test_create_planning_agent(self, crew_planner):
def test_create_planning_agent(self, crew_planner) -> None:
agent = crew_planner._create_planning_agent()
assert isinstance(agent, Agent)
assert agent.role == "Task Execution Planner"
def test_create_planner_task(self, crew_planner):
def test_create_planner_task(self, crew_planner) -> None:
planning_agent = Agent(
role="Planning Agent",
goal="Plan Step by Step Plan",
@@ -92,7 +90,7 @@ class InternalCrewPlanner:
== "Step by step plan on how the agents can execute their tasks using the available tools with mastery"
)
def test_create_tasks_summary(self, crew_planner):
def test_create_tasks_summary(self, crew_planner) -> None:
tasks_summary = crew_planner._create_tasks_summary()
assert isinstance(tasks_summary, str)
assert tasks_summary.startswith("\n Task Number 1 - Task 1")
@@ -100,8 +98,8 @@ class InternalCrewPlanner:
# Knowledge field should not be present when empty
assert '"agent_knowledge"' not in tasks_summary
@patch('crewai.knowledge.storage.knowledge_storage.chromadb')
def test_create_tasks_summary_with_knowledge_and_tools(self, mock_chroma):
@patch("crewai.knowledge.storage.knowledge_storage.chromadb")
def test_create_tasks_summary_with_knowledge_and_tools(self, mock_chroma) -> None:
"""Test task summary generation with both knowledge and tools present."""
# Mock ChromaDB collection
mock_collection = mock_chroma.return_value.get_or_create_collection.return_value
@@ -112,20 +110,20 @@ class InternalCrewPlanner:
name: str
description: str
def __init__(self, name: str, description: str):
def __init__(self, name: str, description: str) -> None:
tool_data = {"name": name, "description": description}
super().__init__(**tool_data)
def __str__(self):
def __str__(self) -> str:
return self.name
def __repr__(self):
def __repr__(self) -> str:
return self.name
def to_structured_tool(self):
return self
def _run(self, *args, **kwargs):
def _run(self, *args, **kwargs) -> None:
pass
def _generate_description(self) -> str:
@@ -145,9 +143,9 @@ class InternalCrewPlanner:
backstory="Test Backstory",
tools=[tool1, tool2],
knowledge_sources=[
StringKnowledgeSource(content="Test knowledge content")
]
)
StringKnowledgeSource(content="Test knowledge content"),
],
),
)
# Create planner with the new task
@@ -163,13 +161,13 @@ class InternalCrewPlanner:
assert task.agent.role in tasks_summary
assert task.agent.goal in tasks_summary
def test_handle_crew_planning_different_llm(self, crew_planner_different_llm):
def test_handle_crew_planning_different_llm(self, crew_planner_different_llm) -> None:
with patch.object(Task, "execute_sync") as execute:
execute.return_value = TaskOutput(
description="Description",
agent="agent",
pydantic=PlannerTaskPydanticOutput(
list_of_plans_per_task=[PlanPerTask(task="Task1", plan="Plan 1")]
list_of_plans_per_task=[PlanPerTask(task="Task1", plan="Plan 1")],
),
)
result = crew_planner_different_llm._handle_crew_planning()
@@ -177,6 +175,6 @@ class InternalCrewPlanner:
assert crew_planner_different_llm.planning_agent_llm == "gpt-3.5-turbo"
assert isinstance(result, PlannerTaskPydanticOutput)
assert len(result.list_of_plans_per_task) == len(
crew_planner_different_llm.tasks
crew_planner_different_llm.tasks,
)
execute.assert_called_once()

View File

@@ -1,12 +1,10 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import pytest
from pydantic import BaseModel, Field
from pydantic import BaseModel
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
def test_simple_model():
def test_simple_model() -> None:
class SimpleModel(BaseModel):
field1: int
field2: str
@@ -21,7 +19,7 @@ def test_simple_model():
assert schema.strip() == expected_schema.strip()
def test_nested_model():
def test_nested_model() -> None:
class NestedModel(BaseModel):
nested_field: int
@@ -42,9 +40,9 @@ def test_nested_model():
assert schema.strip() == expected_schema.strip()
def test_model_with_list():
def test_model_with_list() -> None:
class ListModel(BaseModel):
list_field: List[int]
list_field: list[int]
parser = PydanticSchemaParser(model=ListModel)
schema = parser.get_schema()
@@ -55,9 +53,9 @@ def test_model_with_list():
assert schema.strip() == expected_schema.strip()
def test_model_with_optional_field():
def test_model_with_optional_field() -> None:
class OptionalModel(BaseModel):
optional_field: Optional[str]
optional_field: str | None
parser = PydanticSchemaParser(model=OptionalModel)
schema = parser.get_schema()
@@ -68,9 +66,9 @@ def test_model_with_optional_field():
assert schema.strip() == expected_schema.strip()
def test_model_with_union():
def test_model_with_union() -> None:
class UnionModel(BaseModel):
union_field: Union[int, str]
union_field: int | str
parser = PydanticSchemaParser(model=UnionModel)
schema = parser.get_schema()
@@ -81,9 +79,9 @@ def test_model_with_union():
assert schema.strip() == expected_schema.strip()
def test_model_with_dict():
def test_model_with_dict() -> None:
class DictModel(BaseModel):
dict_field: Dict[str, int]
dict_field: dict[str, int]
parser = PydanticSchemaParser(model=DictModel)
schema = parser.get_schema()

View File

@@ -1,6 +1,4 @@
from datetime import date, datetime
from typing import List
from unittest.mock import Mock
import pytest
from pydantic import BaseModel
@@ -19,11 +17,11 @@ class Person(BaseModel):
age: int
address: Address
birthday: date
skills: List[str]
skills: list[str]
@pytest.mark.parametrize(
"test_input,expected",
("test_input", "expected"),
[
({"text": "hello world"}, {"text": "hello world"}),
({"number": 42}, {"number": 42}),
@@ -36,25 +34,25 @@ class Person(BaseModel):
({"nested": [1, [2, 3], {4, 5}]}, {"nested": [1, [2, 3], [4, 5]]}),
],
)
def test_basic_serialization(test_input, expected):
def test_basic_serialization(test_input, expected) -> None:
result = to_serializable(test_input)
assert result == expected
@pytest.mark.parametrize(
"input_date,expected",
("input_date", "expected"),
[
(date(2024, 1, 1), "2024-01-01"),
(datetime(2024, 1, 1, 12, 30), "2024-01-01T12:30:00"),
],
)
def test_temporal_serialization(input_date, expected):
def test_temporal_serialization(input_date, expected) -> None:
result = to_serializable({"date": input_date})
assert result["date"] == expected
@pytest.mark.parametrize(
"key,value,expected_key_type",
("key", "value", "expected_key_type"),
[
(("tuple", "key"), "value", str),
(None, "value", str),
@@ -62,7 +60,7 @@ def test_temporal_serialization(input_date, expected):
("normal", "value", str),
],
)
def test_dictionary_key_serialization(key, value, expected_key_type):
def test_dictionary_key_serialization(key, value, expected_key_type) -> None:
result = to_serializable({key: value})
assert len(result) == 1
result_key = next(iter(result.keys()))
@@ -71,19 +69,19 @@ def test_dictionary_key_serialization(key, value, expected_key_type):
@pytest.mark.parametrize(
"callable_obj,expected_in_result",
("callable_obj", "expected_in_result"),
[
(lambda x: x * 2, "lambda"),
(str.upper, "upper"),
],
)
def test_callable_serialization(callable_obj, expected_in_result):
def test_callable_serialization(callable_obj, expected_in_result) -> None:
result = to_serializable({"func": callable_obj})
assert isinstance(result["func"], str)
assert expected_in_result in result["func"].lower()
def test_pydantic_model_serialization():
def test_pydantic_model_serialization() -> None:
address = Address(street="123 Main St", city="Tech City", country="Pythonia")
person = Person(
@@ -108,8 +106,8 @@ def test_pydantic_model_serialization():
)
def test_depth_limit():
"""Test max depth handling with a deeply nested structure"""
def test_depth_limit() -> None:
"""Test max depth handling with a deeply nested structure."""
def create_nested(depth):
if depth == 0:
@@ -124,15 +122,15 @@ def test_depth_limit():
"next": {
"next": {
"next": {
"next": "{'next': {'next': {'next': {'next': {'next': 'value'}}}}}"
}
}
}
}
"next": "{'next': {'next': {'next': {'next': {'next': 'value'}}}}}",
},
},
},
},
}
def test_exclude_keys():
def test_exclude_keys() -> None:
result = to_serializable({"key1": "value1", "key2": "value2"}, exclude={"key1"})
assert result == {"key2": "value2"}
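In this file the `pytest.mark.parametrize` argnames string is rewritten as a tuple of names (`"test_input,expected"` becomes `("test_input", "expected")`), the form pytest accepts without splitting a comma-separated string. A self-contained sketch:

```python
import pytest


@pytest.mark.parametrize(
    ("test_input", "expected"),  # tuple of names instead of "test_input,expected"
    [
        ({"text": "hello world"}, {"text": "hello world"}),
        ({"number": 42}, {"number": 42}),
    ],
)
def test_roundtrip(test_input, expected) -> None:
    assert test_input == expected
```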

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union
from typing import Any
import pytest
@@ -8,10 +8,10 @@ from crewai.utilities.string_utils import interpolate_only
class TestInterpolateOnly:
"""Tests for the interpolate_only function in string_utils.py."""
def test_basic_variable_interpolation(self):
def test_basic_variable_interpolation(self) -> None:
"""Test basic variable interpolation works correctly."""
template = "Hello, {name}! Welcome to {company}."
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] = {
inputs: dict[str, str | int | float | dict[str, Any] | list[Any]] = {
"name": "Alice",
"company": "CrewAI",
}
@@ -20,18 +20,18 @@ class TestInterpolateOnly:
assert result == "Hello, Alice! Welcome to CrewAI."
def test_multiple_occurrences_of_same_variable(self):
def test_multiple_occurrences_of_same_variable(self) -> None:
"""Test that multiple occurrences of the same variable are replaced."""
template = "{name} is using {name}'s account."
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] = {
"name": "Bob"
inputs: dict[str, str | int | float | dict[str, Any] | list[Any]] = {
"name": "Bob",
}
result = interpolate_only(template, inputs)
assert result == "Bob is using Bob's account."
def test_json_structure_preservation(self):
def test_json_structure_preservation(self) -> None:
"""Test that JSON structures are preserved and not interpolated incorrectly."""
template = """
Instructions for {agent}:
@@ -40,8 +40,8 @@ class TestInterpolateOnly:
{"name": "person's name", "age": 25, "skills": ["coding", "testing"]}
"""
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] = {
"agent": "DevAgent"
inputs: dict[str, str | int | float | dict[str, Any] | list[Any]] = {
"agent": "DevAgent",
}
result = interpolate_only(template, inputs)
@@ -52,7 +52,7 @@ class TestInterpolateOnly:
in result
)
def test_complex_nested_json(self):
def test_complex_nested_json(self) -> None:
"""Test with complex JSON structures containing curly braces."""
template = """
{agent} needs to process:
@@ -65,8 +65,8 @@ class TestInterpolateOnly:
}
}
"""
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] = {
"agent": "DataProcessor"
inputs: dict[str, str | int | float | dict[str, Any] | list[Any]] = {
"agent": "DataProcessor",
}
result = interpolate_only(template, inputs)
@@ -76,11 +76,11 @@ class TestInterpolateOnly:
assert '"value": 42' in result
assert '[1, 2, {"inner": "value"}]' in result
def test_missing_variable(self):
def test_missing_variable(self) -> None:
"""Test that an error is raised when a required variable is missing."""
template = "Hello, {name}!"
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] = {
"not_name": "Alice"
inputs: dict[str, str | int | float | dict[str, Any] | list[Any]] = {
"not_name": "Alice",
}
with pytest.raises(KeyError) as excinfo:
@@ -89,55 +89,55 @@ class TestInterpolateOnly:
assert "template variable" in str(excinfo.value).lower()
assert "name" in str(excinfo.value)
def test_invalid_input_types(self):
def test_invalid_input_types(self) -> None:
"""Test that an error is raised with invalid input types."""
template = "Hello, {name}!"
# Using Any for this test since we're intentionally testing an invalid type
inputs: Dict[str, Any] = {"name": object()} # Object is not a valid input type
inputs: dict[str, Any] = {"name": object()} # Object is not a valid input type
with pytest.raises(ValueError) as excinfo:
interpolate_only(template, inputs)
assert "unsupported type" in str(excinfo.value).lower()
def test_empty_input_string(self):
def test_empty_input_string(self) -> None:
"""Test handling of empty or None input string."""
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] = {
"name": "Alice"
inputs: dict[str, str | int | float | dict[str, Any] | list[Any]] = {
"name": "Alice",
}
assert interpolate_only("", inputs) == ""
assert interpolate_only(None, inputs) == ""
def test_no_variables_in_template(self):
def test_no_variables_in_template(self) -> None:
"""Test a template with no variables to replace."""
template = "This is a static string with no variables."
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] = {
"name": "Alice"
inputs: dict[str, str | int | float | dict[str, Any] | list[Any]] = {
"name": "Alice",
}
result = interpolate_only(template, inputs)
assert result == template
def test_variable_name_starting_with_underscore(self):
def test_variable_name_starting_with_underscore(self) -> None:
"""Test variables starting with underscore are replaced correctly."""
template = "Variable: {_special_var}"
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] = {
"_special_var": "Special Value"
inputs: dict[str, str | int | float | dict[str, Any] | list[Any]] = {
"_special_var": "Special Value",
}
result = interpolate_only(template, inputs)
assert result == "Variable: Special Value"
def test_preserves_non_matching_braces(self):
def test_preserves_non_matching_braces(self) -> None:
"""Test that non-matching braces patterns are preserved."""
template = (
"This {123} and {!var} should not be replaced but {valid_var} should."
)
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] = {
"valid_var": "works"
inputs: dict[str, str | int | float | dict[str, Any] | list[Any]] = {
"valid_var": "works",
}
result = interpolate_only(template, inputs)
@@ -146,15 +146,15 @@ class TestInterpolateOnly:
result == "This {123} and {!var} should not be replaced but works should."
)
def test_complex_mixed_scenario(self):
def test_complex_mixed_scenario(self) -> None:
"""Test a complex scenario with both valid variables and JSON structures."""
template = """
{agent_name} is working on task {task_id}.
Instructions:
1. Process the data
2. Return results as:
{
"taskId": "{task_id}",
"results": {
@@ -164,7 +164,7 @@ class TestInterpolateOnly:
}
}
"""
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] = {
inputs: dict[str, str | int | float | dict[str, Any] | list[Any]] = {
"agent_name": "AnalyticsAgent",
"task_id": "T-12345",
}
@@ -176,10 +176,10 @@ class TestInterpolateOnly:
assert '"processed_by": "agent_name"' in result # This shouldn't be replaced
assert '"values": [1, 2, 3]' in result
def test_empty_inputs_dictionary(self):
def test_empty_inputs_dictionary(self) -> None:
"""Test that an error is raised with empty inputs dictionary."""
template = "Hello, {name}!"
inputs: Dict[str, Any] = {}
inputs: dict[str, Any] = {}
with pytest.raises(ValueError) as excinfo:
interpolate_only(template, inputs)

View File

@@ -5,14 +5,14 @@ from crewai.utilities.training_handler import CrewTrainingHandler
class InternalCrewTrainingHandler(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.handler = CrewTrainingHandler("trained_data.pkl")
def tearDown(self):
def tearDown(self) -> None:
os.remove("trained_data.pkl")
del self.handler
def test_save_trained_data(self):
def test_save_trained_data(self) -> None:
agent_id = "agent1"
trained_data = {"param1": 1, "param2": 2}
self.handler.save_trained_data(agent_id, trained_data)
@@ -21,7 +21,7 @@ class InternalCrewTrainingHandler(unittest.TestCase):
data = self.handler.load()
assert data[agent_id] == trained_data
def test_append_existing_agent(self):
def test_append_existing_agent(self) -> None:
train_iteration = 1
agent_id = "agent1"
new_data = {"param3": 3, "param4": 4}
@@ -31,7 +31,7 @@ class InternalCrewTrainingHandler(unittest.TestCase):
data = self.handler.load()
assert data[agent_id][train_iteration] == new_data
def test_append_new_agent(self):
def test_append_new_agent(self) -> None:
train_iteration = 1
agent_id = "agent2"
new_data = {"param5": 5, "param6": 6}