mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
add tests for new code
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
import logging # Import logging module
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, TypeVar, cast
|
||||
|
||||
@@ -7,12 +8,17 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Configure logging to display warnings
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
T = TypeVar("T", bound=type)
|
||||
|
||||
"""Base decorator for creating crew classes with configuration and function management."""
|
||||
|
||||
|
||||
def CrewBase(cls: T) -> T:
|
||||
"""Wraps a class with crew functionality and configuration management."""
|
||||
|
||||
class WrappedClass(cls): # type: ignore
|
||||
is_crew_class: bool = True # type: ignore
|
||||
|
||||
@@ -27,11 +33,41 @@ def CrewBase(cls: T) -> T:
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
agents_config_path = self.base_directory / self.original_agents_config_path
|
||||
tasks_config_path = self.base_directory / self.original_tasks_config_path
|
||||
if isinstance(self.original_agents_config_path, str):
|
||||
agents_config_path = (
|
||||
self.base_directory / self.original_agents_config_path
|
||||
)
|
||||
try:
|
||||
self.agents_config = self.load_yaml(agents_config_path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f"Agent config file not found at {agents_config_path}. "
|
||||
"Proceeding with empty agent configurations."
|
||||
)
|
||||
self.agents_config = {}
|
||||
else:
|
||||
logging.warning(
|
||||
"No agent configuration path provided. Proceeding with empty agent configurations."
|
||||
)
|
||||
self.agents_config = {}
|
||||
|
||||
self.agents_config = self.load_yaml(agents_config_path)
|
||||
self.tasks_config = self.load_yaml(tasks_config_path)
|
||||
if isinstance(self.original_tasks_config_path, str):
|
||||
tasks_config_path = (
|
||||
self.base_directory / self.original_tasks_config_path
|
||||
)
|
||||
try:
|
||||
self.tasks_config = self.load_yaml(tasks_config_path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f"Task config file not found at {tasks_config_path}. "
|
||||
"Proceeding with empty task configurations."
|
||||
)
|
||||
self.tasks_config = {}
|
||||
else:
|
||||
logging.warning(
|
||||
"No task configuration path provided. Proceeding with empty task configurations."
|
||||
)
|
||||
self.tasks_config = {}
|
||||
|
||||
self.map_all_agent_variables()
|
||||
self.map_all_task_variables()
|
||||
|
||||
@@ -16,6 +16,7 @@ from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.process import Process
|
||||
from crewai.project import crew
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
@@ -3474,3 +3475,115 @@ def test_crew_guardrail_feedback_in_context():
|
||||
|
||||
# Verify task retry count
|
||||
assert task.retry_count == 1, "Task should have been retried once"
|
||||
|
||||
|
||||
def test_before_kickoff_callback():
|
||||
from crewai.project import CrewBase, agent, before_kickoff, crew, task
|
||||
|
||||
@CrewBase
|
||||
class TestCrewClass:
|
||||
agents_config = None
|
||||
tasks_config = None
|
||||
|
||||
def __init__(self):
|
||||
self.inputs_modified = False
|
||||
|
||||
@before_kickoff
|
||||
def modify_inputs(self, inputs):
|
||||
|
||||
self.inputs_modified = True
|
||||
inputs["modified"] = True
|
||||
return inputs
|
||||
|
||||
@agent
|
||||
def my_agent(self):
|
||||
return Agent(
|
||||
role="Test Agent",
|
||||
goal="Test agent goal",
|
||||
backstory="Test agent backstory",
|
||||
)
|
||||
|
||||
@task
|
||||
def my_task(self):
|
||||
task = Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=self.my_agent(), # Use the agent instance
|
||||
)
|
||||
return task
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=self.tasks)
|
||||
|
||||
test_crew_instance = TestCrewClass().crew()
|
||||
|
||||
# Verify that the before_kickoff_callbacks are set
|
||||
assert len(crew.before_kickoff_callbacks) == 1
|
||||
|
||||
# Prepare inputs
|
||||
inputs = {"initial": True}
|
||||
|
||||
# Call kickoff
|
||||
test_crew_instance.kickoff(inputs=inputs)
|
||||
|
||||
# Check that the before_kickoff function was called and modified inputs
|
||||
assert test_crew_instance.inputs_modified
|
||||
assert inputs.get("modified") == True
|
||||
|
||||
|
||||
def test_before_kickoff_without_inputs():
|
||||
from crewai.project import CrewBase, agent, before_kickoff, crew, task
|
||||
|
||||
@CrewBase
|
||||
class TestCrewClass:
|
||||
agents_config = None
|
||||
tasks_config = None
|
||||
|
||||
def __init__(self):
|
||||
self.inputs_modified = False
|
||||
self.received_inputs = None
|
||||
|
||||
@before_kickoff
|
||||
def modify_inputs(self, inputs):
|
||||
self.inputs_modified = True
|
||||
inputs["modified"] = True
|
||||
self.received_inputs = inputs
|
||||
return inputs
|
||||
|
||||
@agent
|
||||
def my_agent(self):
|
||||
return Agent(
|
||||
role="Test Agent",
|
||||
goal="Test agent goal",
|
||||
backstory="Test agent backstory",
|
||||
)
|
||||
|
||||
@task
|
||||
def my_task(self):
|
||||
return Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=self.my_agent(),
|
||||
)
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=self.tasks)
|
||||
|
||||
# Instantiate the class
|
||||
test_crew_instance = TestCrewClass()
|
||||
# Build the crew
|
||||
crew = test_crew_instance.crew()
|
||||
# Verify that the before_kickoff_callback is registered
|
||||
assert len(crew.before_kickoff_callbacks) == 1
|
||||
|
||||
# Call kickoff without passing inputs
|
||||
output = crew.kickoff()
|
||||
|
||||
# Check that the before_kickoff function was called
|
||||
assert test_crew_instance.inputs_modified
|
||||
|
||||
# Verify that the inputs were initialized and modified inside the before_kickoff method
|
||||
assert test_crew_instance.received_inputs is not None
|
||||
assert test_crew_instance.received_inputs.get("modified") is True
|
||||
|
||||
Reference in New Issue
Block a user