mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
305 lines
9.7 KiB
Python
305 lines
9.7 KiB
Python
import pytest
|
|
|
|
from crewai.agent import Agent
|
|
from crewai.crew import Crew
|
|
from crewai.project import CrewBase, after_kickoff, agent, before_kickoff, crew, task
|
|
from crewai.task import Task
|
|
|
|
|
|
class SimpleCrew:
    """Bare crew (no ``@CrewBase``) used to exercise ``@agent``/``@task`` behaviour.

    The memoization tests call these methods twice on one instance. ``test_task_name``
    relies on the method names below, so they must not be renamed.
    """

    @agent
    def simple_agent(self):
        """Return the single agent of this crew."""
        return Agent(
            role="Simple Agent",
            goal="Simple Goal",
            backstory="Simple Backstory",
        )

    @task
    def simple_task(self):
        """Return a task without an explicit name (the name is expected to be inferred)."""
        return Task(
            description="Simple Description",
            expected_output="Simple Output",
        )

    @task
    def custom_named_task(self):
        """Return a task whose explicit ``name`` must survive the decorator."""
        return Task(
            name="Custom",
            description="Simple Description",
            expected_output="Simple Output",
        )
|
|
|
|
|
|
@CrewBase
class TestCrew:
    """YAML-configured crew with before/after kickoff hooks, shared by the kickoff tests."""

    # The class name starts with "Test", so pytest's default discovery would try to
    # collect this helper as a test class; opt it out explicitly.
    __test__ = False

    # Indexed as mappings below, so @CrewBase presumably replaces these paths with
    # the parsed YAML content — confirm against the CrewBase implementation.
    agents_config = "config/agents.yaml"
    tasks_config = "config/tasks.yaml"

    @agent
    def researcher(self):
        return Agent(config=self.agents_config["researcher"])

    @agent
    def reporting_analyst(self):
        return Agent(config=self.agents_config["reporting_analyst"])

    @task
    def research_task(self):
        return Task(config=self.tasks_config["research_task"])

    @task
    def reporting_task(self):
        return Task(config=self.tasks_config["reporting_task"])

    @before_kickoff
    def modify_inputs(self, inputs):
        """Force ``topic`` to "Bicycles"; falsy inputs (``None`` or ``{}``) pass through untouched."""
        if inputs:
            inputs["topic"] = "Bicycles"
        return inputs

    @after_kickoff
    def modify_outputs(self, outputs):
        """Append " post processed" to ``outputs.raw`` so tests can detect the hook ran."""
        outputs.raw = outputs.raw + " post processed"
        return outputs

    @crew
    def crew(self):
        return Crew(agents=self.agents, tasks=self.tasks, verbose=True)
|
|
|
|
|
|
def test_agent_memoization():
    """Two calls to an ``@agent`` method on one instance must yield the same object."""
    simple_crew = SimpleCrew()
    agents = [simple_crew.simple_agent() for _ in range(2)]

    assert agents[0] is agents[1], "Agent memoization is not working as expected"
|
|
|
|
|
|
def test_task_memoization():
    """Two calls to a ``@task`` method on one instance must yield the same object."""
    simple_crew = SimpleCrew()
    first = simple_crew.simple_task()
    second = simple_crew.simple_task()

    assert first is second, "Task memoization is not working as expected"
|
|
|
|
|
|
def test_crew_memoization():
    """Repeated calls to the ``@crew`` method must hand back the same Crew object."""
    crew_base = TestCrew()
    crews = [crew_base.crew() for _ in range(2)]

    assert crews[0] is crews[1], "Crew references should point to the same object"
|
|
|
|
|
|
def test_task_name():
    """``@task`` names a task after its method unless ``Task(name=...)`` was given."""
    cases = (
        (
            "simple_task",
            "simple_task",
            "Task name is not inferred from function name as expected",
        ),
        (
            "custom_named_task",
            "Custom",
            "Custom task name is not being set as expected",
        ),
    )
    # A fresh SimpleCrew per case, built and checked in order.
    for method_name, expected_name, failure_message in cases:
        built_task = getattr(SimpleCrew(), method_name)()
        assert built_task.name == expected_name, failure_message
|
|
|
|
|
|
@pytest.mark.vcr(filter_headers=["authorization"])
def test_before_kickoff_modification():
    """The ``@before_kickoff`` hook should replace the requested topic with "Bicycles".

    The raw output is model-generated text (replayed via the VCR cassette). The hook
    writes "Bicycles", but the model may echo it in any case, so compare
    case-insensitively.
    """
    # Not named `crew`, to avoid shadowing the imported @crew decorator.
    crew_base = TestCrew()
    inputs = {"topic": "LLMs"}
    result = crew_base.crew().kickoff(inputs=inputs)

    assert "bicycles" in result.raw.lower(), "Before kickoff function did not modify inputs"
|
|
|
|
|
|
@pytest.mark.vcr(filter_headers=["authorization"])
def test_after_kickoff_modification():
    """The ``@after_kickoff`` hook should append " post processed" to the raw output."""
    # Not named `crew`, to avoid shadowing the imported @crew decorator.
    crew_base = TestCrew()
    # kickoff() returns an output object exposing ``.raw`` (not a dict); the hook
    # reassigns that attribute and returns the same object.
    result = crew_base.crew().kickoff({"topic": "LLMs"})

    assert "post processed" in result.raw, "After kickoff function did not modify outputs"
|
|
|
|
|
|
@pytest.mark.vcr(filter_headers=["authorization"])
def test_before_kickoff_with_none_input():
    """Kicking off with ``None`` inputs must not raise (the hook skips falsy inputs)."""
    TestCrew().crew().kickoff(None)
    # Success criterion: reaching this point without an exception.
|
|
|
|
|
|
@pytest.mark.vcr(filter_headers=["authorization"])
def test_multiple_before_after_kickoff():
    """Every registered before/after kickoff hook should run.

    Both before hooks overwrite ``topic``. Seeing "plants" (set by ``second_before``)
    shows the second hook ran, and ran after ``first_before``. Each after hook appends
    its own marker to the raw output.
    """

    @CrewBase
    class MultipleHooksCrew:
        agents_config = "config/agents.yaml"
        tasks_config = "config/tasks.yaml"

        @agent
        def researcher(self):
            return Agent(config=self.agents_config["researcher"])

        @agent
        def reporting_analyst(self):
            return Agent(config=self.agents_config["reporting_analyst"])

        @task
        def research_task(self):
            return Task(config=self.tasks_config["research_task"])

        @task
        def reporting_task(self):
            return Task(config=self.tasks_config["reporting_task"])

        @before_kickoff
        def first_before(self, inputs):
            inputs["topic"] = "Bicycles"
            return inputs

        @before_kickoff
        def second_before(self, inputs):
            inputs["topic"] = "plants"
            return inputs

        @after_kickoff
        def first_after(self, outputs):
            outputs.raw = outputs.raw + " processed first"
            return outputs

        @after_kickoff
        def second_after(self, outputs):
            outputs.raw = outputs.raw + " processed second"
            return outputs

        @crew
        def crew(self):
            return Crew(agents=self.agents, tasks=self.tasks, verbose=True)

    # Deliberately not named `crew`. A local of that name would shadow the @crew
    # decorator used inside the class above.
    hooks_crew = MultipleHooksCrew()
    result = hooks_crew.crew().kickoff({"topic": "LLMs"})

    # Model-generated text may capitalise the topic, so compare case-insensitively.
    assert "plants" in result.raw.lower(), (
        "Second before_kickoff not executed (or before hooks ran out of order)"
    )
    assert "processed first" in result.raw, "First after_kickoff not executed"
    assert "processed second" in result.raw, "Second after_kickoff not executed"
|
|
|
|
|
|
@pytest.mark.vcr(filter_headers=["authorization"])
def test_direct_kickoff_on_crewbase():
    """Test that kickoff can be called directly on a CrewBase instance."""

    class MockCrewBase:
        def __init__(self):
            # Not read anywhere in this test; kept to mimic a CrewBase-like instance.
            self._kickoff = {"crew": lambda: self}

        def crew(self):
            class MockOutput:
                def __init__(self, raw):
                    self.raw = raw

            class MockCrew:
                def kickoff(self, inputs=None):
                    # Stand-in for a before_kickoff hook rewriting the inputs.
                    if inputs:
                        inputs["topic"] = "Bicycles"
                    # Echo the (possibly rewritten) topic so the assertion below
                    # really depends on the inputs having been modified.
                    topic = inputs.get("topic", "unknown") if inputs else "unknown"
                    return MockOutput(f"test output with {topic} post processed")

            return MockCrew()

        def kickoff(self, inputs=None):
            return self.crew().kickoff(inputs)

    crew_base = MockCrewBase()
    result = crew_base.kickoff({"topic": "LLMs"})

    assert "bicycles" in result.raw.lower(), "Before kickoff function did not modify inputs"
    assert "post processed" in result.raw, "After kickoff function did not modify outputs"
|
|
|
|
|
|
@pytest.mark.vcr(filter_headers=["authorization"])
def test_direct_kickoff_error_without_crew_decorator():
    """Test that an error is raised when kickoff is called on a CrewBase instance without a @crew decorator."""

    class MockCrewBase:
        def __init__(self):
            # Empty registry: no @crew-decorated method is available.
            self._kickoff = {}

        def kickoff(self, inputs=None):
            if not self._kickoff:
                raise AttributeError("No method with @crew decorator found. Add a method with @crew decorator to your class.")
            return None

    crew_base = MockCrewBase()

    # Pin the message so an unrelated AttributeError cannot make this test pass.
    with pytest.raises(AttributeError, match="No method with @crew decorator found"):
        crew_base.kickoff()
|
|
|
|
|
|
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.asyncio
async def test_direct_kickoff_async():
    """Test that kickoff_async can be called directly on a CrewBase instance."""

    class MockCrewBase:
        def __init__(self):
            # Not read anywhere in this test; kept to mimic a CrewBase-like instance.
            self._kickoff = {"crew": lambda: self}

        def crew(self):
            class MockOutput:
                def __init__(self, raw):
                    self.raw = raw

            class MockCrew:
                async def kickoff_async(self, inputs=None):
                    # Stand-in for a before_kickoff hook rewriting the inputs.
                    if inputs:
                        inputs["topic"] = "Bicycles"
                    # Echo the (possibly rewritten) topic so the assertion below
                    # really depends on the inputs having been modified.
                    topic = inputs.get("topic", "unknown") if inputs else "unknown"
                    return MockOutput(f"test async output with {topic} post processed")

            return MockCrew()

        def kickoff_async(self, inputs=None):
            # Returns the coroutine; the caller awaits it.
            return self.crew().kickoff_async(inputs=inputs)

    crew_base = MockCrewBase()
    result = await crew_base.kickoff_async({"topic": "LLMs"})

    assert "bicycles" in result.raw.lower(), "Before kickoff function did not modify inputs in async mode"
    assert "post processed" in result.raw, "After kickoff function did not modify outputs in async mode"
|
|
|
|
|
|
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.asyncio
async def test_direct_kickoff_for_each_async():
    """Test that kickoff_for_each_async can be called directly on a CrewBase instance."""

    class MockCrewBase:
        def __init__(self):
            self._kickoff = {"crew": lambda: self}

        def crew(self):
            class MockOutput:
                def __init__(self, topic):
                    self.raw = f"test for_each_async output with {topic} post processed"

            def run_single(item):
                # Prefix the topic in place (mimicking a before hook), then wrap it.
                if "topic" in item:
                    item["topic"] = f"Bicycles-{item['topic']}"
                return MockOutput(item.get("topic", "unknown"))

            class MockCrew:
                async def kickoff_for_each_async(self, inputs=None):
                    return [run_single(item) for item in inputs]

            return MockCrew()

        def kickoff_for_each_async(self, inputs=None):
            return self.crew().kickoff_for_each_async(inputs=inputs)

    crew_base = MockCrewBase()
    results = await crew_base.kickoff_for_each_async([{"topic": "LLMs"}, {"topic": "AI"}])

    first_raw, second_raw = (r.raw.lower() for r in results[:2]) if len(results) >= 2 else ("", "")
    assert len(results) == 2, "Should return results for each input"
    assert "bicycles-llms" in first_raw, "First input was not processed correctly"
    assert "bicycles-ai" in second_raw, "Second input was not processed correctly"
    assert all("post processed" in result.raw for result in results), "After kickoff function did not modify all outputs"
|