diff --git a/src/crewai/agents/agent_builder/utilities/base_token_process.py b/src/crewai/agents/agent_builder/utilities/base_token_process.py index 322fade0e..3ce5cfb82 100644 --- a/src/crewai/agents/agent_builder/utilities/base_token_process.py +++ b/src/crewai/agents/agent_builder/utilities/base_token_process.py @@ -2,26 +2,26 @@ from crewai.types.usage_metrics import UsageMetrics class TokenProcess: - def __init__(self): + def __init__(self) -> None: self.total_tokens: int = 0 self.prompt_tokens: int = 0 self.cached_prompt_tokens: int = 0 self.completion_tokens: int = 0 self.successful_requests: int = 0 - def sum_prompt_tokens(self, tokens: int): - self.prompt_tokens = self.prompt_tokens + tokens - self.total_tokens = self.total_tokens + tokens + def sum_prompt_tokens(self, tokens: int) -> None: + self.prompt_tokens += tokens + self.total_tokens += tokens - def sum_completion_tokens(self, tokens: int): - self.completion_tokens = self.completion_tokens + tokens - self.total_tokens = self.total_tokens + tokens + def sum_completion_tokens(self, tokens: int) -> None: + self.completion_tokens += tokens + self.total_tokens += tokens - def sum_cached_prompt_tokens(self, tokens: int): - self.cached_prompt_tokens = self.cached_prompt_tokens + tokens + def sum_cached_prompt_tokens(self, tokens: int) -> None: + self.cached_prompt_tokens += tokens - def sum_successful_requests(self, requests: int): - self.successful_requests = self.successful_requests + requests + def sum_successful_requests(self, requests: int) -> None: + self.successful_requests += requests def get_summary(self) -> UsageMetrics: return UsageMetrics( diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index ef688b9c1..ad91d1236 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,6 +1,5 @@ import asyncio import inspect -import uuid from typing import ( Any, Callable, @@ -13,7 +12,6 @@ from typing import ( TypeVar, Union, cast, - overload, ) from uuid import uuid4 @@ -27,7 +25,6 @@ from crewai.flow.flow_events import ( MethodExecutionStartedEvent, ) from crewai.flow.flow_visualizer import plot_flow -from crewai.flow.persistence import FlowPersistence from crewai.flow.persistence.base import FlowPersistence from crewai.flow.utils import get_possible_return_constants from crewai.telemetry import Telemetry @@ -35,22 +32,32 @@ from crewai.telemetry import Telemetry class FlowState(BaseModel): """Base model for all flow states, ensuring each state has a unique ID.""" - id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the flow state") + + id: str = Field( + default_factory=lambda: str(uuid4()), + description="Unique identifier for the flow state", + ) + # Type variables with explicit bounds -T = TypeVar("T", bound=Union[Dict[str, Any], BaseModel]) # Generic flow state type parameter -StateT = TypeVar("StateT", bound=Union[Dict[str, Any], BaseModel]) # State validation type parameter +T = TypeVar( + "T", bound=Union[Dict[str, Any], BaseModel] +) # Generic flow state type parameter +StateT = TypeVar( + "StateT", bound=Union[Dict[str, Any], BaseModel] +) # State validation type parameter + def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT: """Ensure state matches expected type with proper validation. - + Args: state: State instance to validate expected_type: Expected type for the state - + Returns: Validated state instance - + Raises: TypeError: If state doesn't match expected type ValueError: If state validation fails @@ -68,13 +75,15 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT: TypeError: If state doesn't match expected type ValueError: If state validation fails """ - if expected_type == dict: + if expected_type is dict: if not isinstance(state, dict): raise TypeError(f"Expected dict, got {type(state).__name__}") return cast(StateT, state) if isinstance(expected_type, type) and issubclass(expected_type, BaseModel): if not isinstance(state, expected_type): - raise TypeError(f"Expected {expected_type.__name__}, got {type(state).__name__}") + raise TypeError( + f"Expected {expected_type.__name__}, got {type(state).__name__}" + ) return cast(StateT, state) raise TypeError(f"Invalid expected_type: {expected_type}") @@ -120,6 +129,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable: >>> def complex_start(self): ... pass """ + def decorator(func): func.__is_start_method__ = True if condition is not None: @@ -144,6 +154,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable: return decorator + def listen(condition: Union[str, dict, Callable]) -> Callable: """ Creates a listener that executes when specified conditions are met. @@ -180,6 +191,7 @@ def listen(condition: Union[str, dict, Callable]) -> Callable: >>> def handle_completion(self): ... pass """ + def decorator(func): if isinstance(condition, str): func.__trigger_methods__ = [condition] @@ -244,6 +256,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable: ... return CONTINUE ... return STOP """ + def decorator(func): func.__is_router__ = True if isinstance(condition, str): @@ -267,6 +280,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable: return decorator + def or_(*conditions: Union[str, dict, Callable]) -> dict: """ Combines multiple conditions with OR logic for flow control. @@ -370,22 +384,27 @@ class FlowMeta(type): for attr_name, attr_value in dct.items(): # Check for any flow-related attributes - if (hasattr(attr_value, "__is_flow_method__") or - hasattr(attr_value, "__is_start_method__") or - hasattr(attr_value, "__trigger_methods__") or - hasattr(attr_value, "__is_router__")): - + if ( + hasattr(attr_value, "__is_flow_method__") + or hasattr(attr_value, "__is_start_method__") + or hasattr(attr_value, "__trigger_methods__") + or hasattr(attr_value, "__is_router__") + ): + # Register start methods if hasattr(attr_value, "__is_start_method__"): start_methods.append(attr_name) - + # Register listeners and routers if hasattr(attr_value, "__trigger_methods__"): methods = attr_value.__trigger_methods__ condition_type = getattr(attr_value, "__condition_type__", "OR") listeners[attr_name] = (condition_type, methods) - - if hasattr(attr_value, "__is_router__") and attr_value.__is_router__: + + if ( + hasattr(attr_value, "__is_router__") + and attr_value.__is_router__ + ): routers.add(attr_name) possible_returns = get_possible_return_constants(attr_value) if possible_returns: @@ -401,8 +420,9 @@ class FlowMeta(type): class Flow(Generic[T], metaclass=FlowMeta): """Base class for all flows. - + Type parameter T must be either Dict[str, Any] or a subclass of BaseModel.""" + _telemetry = Telemetry() _start_methods: List[str] = [] @@ -426,7 +446,7 @@ class Flow(Generic[T], metaclass=FlowMeta): **kwargs: Any, ) -> None: """Initialize a new Flow instance. - + Args: persistence: Optional persistence backend for storing flow states restore_uuid: Optional UUID to restore state from persistence @@ -438,29 +458,38 @@ class Flow(Generic[T], metaclass=FlowMeta): self._pending_and_listeners: Dict[str, Set[str]] = {} self._method_outputs: List[Any] = [] # List to store all method outputs self._persistence: Optional[FlowPersistence] = persistence - + # Validate state model before initialization if isinstance(self.initial_state, type): - if issubclass(self.initial_state, BaseModel) and not issubclass(self.initial_state, FlowState): + if issubclass(self.initial_state, BaseModel) and not issubclass( + self.initial_state, FlowState + ): # Check if model has id field model_fields = getattr(self.initial_state, "model_fields", None) if not model_fields or "id" not in model_fields: raise ValueError("Flow state model must have an 'id' field") - + # Handle persistence and potential ID conflicts stored_state = None if self._persistence is not None: - if restore_uuid and kwargs and "id" in kwargs and restore_uuid != kwargs["id"]: + if ( + restore_uuid + and kwargs + and "id" in kwargs + and restore_uuid != kwargs["id"] + ): raise ValueError( f"Conflicting IDs provided: restore_uuid='{restore_uuid}' " f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration." ) - + # Attempt to load state, prioritizing restore_uuid if restore_uuid: stored_state = self._persistence.load_state(restore_uuid) if not stored_state: - raise ValueError(f"No state found for restore_uuid='{restore_uuid}'") + raise ValueError( + f"No state found for restore_uuid='{restore_uuid}'" + ) elif kwargs and "id" in kwargs: stored_state = self._persistence.load_state(kwargs["id"]) if not stored_state: @@ -469,7 +498,7 @@ class Flow(Generic[T], metaclass=FlowMeta): if kwargs: self._initialize_state(kwargs) return - + # Initialize state based on persistence and kwargs if stored_state: # Create initial state and restore from persistence @@ -494,23 +523,23 @@ class Flow(Generic[T], metaclass=FlowMeta): if not method_name.startswith("_"): method = getattr(self, method_name) # Check for any flow-related attributes - if (hasattr(method, "__is_flow_method__") or - hasattr(method, "__is_start_method__") or - hasattr(method, "__trigger_methods__") or - hasattr(method, "__is_router__")): + if ( + hasattr(method, "__is_flow_method__") + or hasattr(method, "__is_start_method__") + or hasattr(method, "__trigger_methods__") + or hasattr(method, "__is_router__") + ): # Ensure method is bound to this instance if not hasattr(method, "__self__"): method = method.__get__(self, self.__class__) self._methods[method_name] = method - - def _create_initial_state(self) -> T: """Create and initialize flow state with UUID and default values. - + Returns: New state instance with UUID and default values initialized - + Raises: ValueError: If structured state model lacks 'id' field TypeError: If state is neither BaseModel nor dictionary @@ -522,24 +551,25 @@ class Flow(Generic[T], metaclass=FlowMeta): if issubclass(state_type, FlowState): # Create instance without id, then set it instance = state_type() - if not hasattr(instance, 'id'): - setattr(instance, 'id', str(uuid4())) + if not hasattr(instance, "id"): + setattr(instance, "id", str(uuid4())) return cast(T, instance) elif issubclass(state_type, BaseModel): # Create a new type that includes the ID field class StateWithId(state_type, FlowState): # type: ignore pass + instance = StateWithId() - if not hasattr(instance, 'id'): - setattr(instance, 'id', str(uuid4())) + if not hasattr(instance, "id"): + setattr(instance, "id", str(uuid4())) return cast(T, instance) - elif state_type == dict: - return cast(T, {"id": str(uuid4())}) # Minimal dict state - + elif state_type is dict: + return cast(T, {"id": str(uuid4())}) + # Handle case where no initial state is provided if self.initial_state is None: return cast(T, {"id": str(uuid4())}) - + # Handle case where initial_state is a type (class) if isinstance(self.initial_state, type): if issubclass(self.initial_state, FlowState): @@ -550,22 +580,22 @@ class Flow(Generic[T], metaclass=FlowMeta): if not model_fields or "id" not in model_fields: raise ValueError("Flow state model must have an 'id' field") return cast(T, self.initial_state()) # Uses model defaults - elif self.initial_state == dict: + elif self.initial_state is dict: return cast(T, {"id": str(uuid4())}) - + # Handle dictionary instance case if isinstance(self.initial_state, dict): new_state = dict(self.initial_state) # Copy to avoid mutations if "id" not in new_state: new_state["id"] = str(uuid4()) return cast(T, new_state) - + # Handle BaseModel instance case if isinstance(self.initial_state, BaseModel): model = cast(BaseModel, self.initial_state) if not hasattr(model, "id"): raise ValueError("Flow state model must have an 'id' field") - + # Create new instance with same values to avoid mutations if hasattr(model, "model_dump"): # Pydantic v2 @@ -576,60 +606,13 @@ class Flow(Generic[T], metaclass=FlowMeta): else: # Fallback for other BaseModel implementations state_dict = { - k: v for k, v in model.__dict__.items() - if not k.startswith("_") + k: v for k, v in model.__dict__.items() if not k.startswith("_") } - + # Create new instance of the same class model_class = type(model) return cast(T, model_class(**state_dict)) - - raise TypeError( - f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" - ) - # Handle case where initial_state is None but we have a type parameter - if self.initial_state is None and hasattr(self, "_initial_state_T"): - state_type = getattr(self, "_initial_state_T") - if isinstance(state_type, type): - if issubclass(state_type, FlowState): - return cast(T, state_type()) - elif issubclass(state_type, BaseModel): - # Create a new type that includes the ID field - class StateWithId(state_type, FlowState): # type: ignore - pass - return cast(T, StateWithId()) - elif state_type == dict: - return cast(T, {"id": str(uuid4())}) - # Handle case where no initial state is provided - if self.initial_state is None: - return cast(T, {"id": str(uuid4())}) - - # Handle case where initial_state is a type (class) - if isinstance(self.initial_state, type): - if issubclass(self.initial_state, FlowState): - return cast(T, self.initial_state()) - elif issubclass(self.initial_state, BaseModel): - # Validate that the model has an id field - model_fields = getattr(self.initial_state, "model_fields", None) - if not model_fields or "id" not in model_fields: - raise ValueError("Flow state model must have an 'id' field") - return cast(T, self.initial_state()) - elif self.initial_state == dict: - return cast(T, {"id": str(uuid4())}) - - # Handle dictionary instance case - if isinstance(self.initial_state, dict): - if "id" not in self.initial_state: - self.initial_state["id"] = str(uuid4()) - return cast(T, dict(self.initial_state)) # Create new dict to avoid mutations - - # Handle BaseModel instance case - if isinstance(self.initial_state, BaseModel): - if not hasattr(self.initial_state, "id"): - raise ValueError("Flow state model must have an 'id' field") - return cast(T, self.initial_state) - raise TypeError( f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" ) @@ -645,10 +628,10 @@ class Flow(Generic[T], metaclass=FlowMeta): def _initialize_state(self, inputs: Dict[str, Any]) -> None: """Initialize or update flow state with new inputs. - + Args: inputs: Dictionary of state values to set/update - + Raises: ValueError: If validation fails for structured state TypeError: If state is neither BaseModel nor dictionary @@ -675,13 +658,12 @@ class Flow(Generic[T], metaclass=FlowMeta): current_state = model.dict() else: current_state = { - k: v for k, v in model.__dict__.items() - if not k.startswith("_") + k: v for k, v in model.__dict__.items() if not k.startswith("_") } - + # Create new state with preserved fields and updates new_state = {**current_state, **inputs} - + # Create new instance with merged state model_class = type(model) if hasattr(model_class, "model_validate"): @@ -697,13 +679,13 @@ class Flow(Generic[T], metaclass=FlowMeta): raise ValueError(f"Invalid inputs for structured state: {e}") from e else: raise TypeError("State must be a BaseModel instance or a dictionary.") - + def _restore_state(self, stored_state: Dict[str, Any]) -> None: """Restore flow state from persistence. - + Args: stored_state: Previously stored state to restore - + Raises: ValueError: If validation fails for structured state TypeError: If state is neither BaseModel nor dictionary @@ -712,7 +694,7 @@ class Flow(Generic[T], metaclass=FlowMeta): stored_id = stored_state.get("id") if not stored_id: raise ValueError("Stored state must have an 'id' field") - + if isinstance(self._state, dict): # For dict states, update all fields from stored state self._state.clear() @@ -730,9 +712,7 @@ class Flow(Generic[T], metaclass=FlowMeta): # Fallback for other BaseModel implementations self._state = cast(T, type(model)(**stored_state)) else: - raise TypeError( - f"State must be dict or BaseModel, got {type(self._state)}" - ) + raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}") def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: self.event_emitter.send( diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index 13230edf6..8c0287509 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -24,10 +24,12 @@ def create_llm( # 1) If llm_value is already an LLM object, return it directly if isinstance(llm_value, LLM): + print("LLM value is already an LLM object") return llm_value # 2) If llm_value is a string (model name) if isinstance(llm_value, str): + print("LLM value is a string") try: created_llm = LLM(model=llm_value) return created_llm @@ -37,10 +39,12 @@ def create_llm( # 3) If llm_value is None, parse environment variables or use default if llm_value is None: + print("LLM value is None") return _llm_via_environment_or_fallback() # 4) Otherwise, attempt to extract relevant attributes from an unknown object try: + print("LLM value is an unknown object") # Extract attributes with explicit types model = ( getattr(llm_value, "model_name", None) diff --git a/tests/agent_test.py b/tests/agent_test.py index 9df80141a..859a5d693 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -114,35 +114,6 @@ def test_custom_llm_temperature_preservation(): assert agent.llm.temperature == 0.7 -@pytest.mark.vcr(filter_headers=["authorization"]) -def test_agent_execute_task(): - from langchain_openai import ChatOpenAI - - from crewai import Task - - agent = Agent( - role="Math Tutor", - goal="Solve math problems accurately", - backstory="You are an experienced math tutor with a knack for explaining complex concepts simply.", - llm=ChatOpenAI(temperature=0.7, model="gpt-4o-mini"), - ) - - task = Task( - description="Calculate the area of a circle with radius 5 cm.", - expected_output="The calculated area of the circle in square centimeters.", - agent=agent, - ) - - result = agent.execute_task(task) - - assert result is not None - assert ( - result - == "The calculated area of the circle is approximately 78.5 square centimeters." - ) - assert "square centimeters" in result.lower() - - @pytest.mark.vcr(filter_headers=["authorization"]) def test_agent_execution(): agent = Agent( diff --git a/tests/cassettes/test_agent_execute_task.yaml b/tests/cassettes/test_agent_execute_task.yaml deleted file mode 100644 index d390b176d..000000000 --- a/tests/cassettes/test_agent_execute_task.yaml +++ /dev/null @@ -1,121 +0,0 @@ -interactions: -- request: - body: '{"messages": [{"role": "system", "content": "You are Math Tutor. You are - an experienced math tutor with a knack for explaining complex concepts simply.\nYour - personal goal is: Solve math problems accurately\nTo give my best complete final - answer to the task use the exact following format:\n\nThought: I now can give - a great answer\nFinal Answer: Your final answer must be the great and the most - complete as possible, it must be outcome described.\n\nI MUST use these formats, - my job depends on it!"}, {"role": "user", "content": "\nCurrent Task: Calculate - the area of a circle with radius 5 cm.\n\nThis is the expect criteria for your - final answer: The calculated area of the circle in square centimeters.\nyou - MUST return the actual complete content as the final answer, not a summary.\n\nBegin! - This is VERY important to you, use the tools available and give your best Final - Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "temperature": - 0.7}' - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '969' - content-type: - - application/json - host: - - api.openai.com - user-agent: - - OpenAI/Python 1.47.0 - x-stainless-arch: - - arm64 - x-stainless-async: - - 'false' - x-stainless-lang: - - python - x-stainless-os: - - MacOS - x-stainless-package-version: - - 1.47.0 - x-stainless-raw-response: - - 'true' - x-stainless-runtime: - - CPython - x-stainless-runtime-version: - - 3.11.7 - method: POST - uri: https://api.openai.com/v1/chat/completions - response: - content: "{\n \"id\": \"chatcmpl-AB7LEfa5gX4cncpI4avsK0CJG8pCb\",\n \"object\": - \"chat.completion\",\n \"created\": 1727213192,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n - \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": - \"assistant\",\n \"content\": \"I now can give a great answer\\n\\nTo - calculate the area of a circle, we use the formula:\\n\\n\\\\[ A = \\\\pi r^2 - \\\\]\\n\\nwhere \\\\( A \\\\) is the area, \\\\( \\\\pi \\\\) (approximately - 3.14), and \\\\( r \\\\) is the radius of the circle.\\n\\nGiven that the radius - \\\\( r \\\\) is 5 cm, we can substitute this value into the formula:\\n\\n\\\\[ - A = \\\\pi (5 \\\\, \\\\text{cm})^2 \\\\]\\n\\nCalculating this step-by-step:\\n\\n1. - First, square the radius:\\n \\\\[ (5 \\\\, \\\\text{cm})^2 = 25 \\\\, \\\\text{cm}^2 - \\\\]\\n\\n2. Then, multiply by \\\\( \\\\pi \\\\):\\n \\\\[ A = \\\\pi \\\\times - 25 \\\\, \\\\text{cm}^2 \\\\]\\n\\nUsing the approximate value of \\\\( \\\\pi - \\\\):\\n \\\\[ A \\\\approx 3.14 \\\\times 25 \\\\, \\\\text{cm}^2 \\\\]\\n - \ \\\\[ A \\\\approx 78.5 \\\\, \\\\text{cm}^2 \\\\]\\n\\nThus, the area of - the circle is approximately 78.5 square centimeters.\\n\\nFinal Answer: The - calculated area of the circle is approximately 78.5 square centimeters.\",\n - \ \"refusal\": null\n },\n \"logprobs\": null,\n \"finish_reason\": - \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 182,\n \"completion_tokens\": - 270,\n \"total_tokens\": 452,\n \"completion_tokens_details\": {\n \"reasoning_tokens\": - 0\n }\n },\n \"system_fingerprint\": \"fp_1bb46167f9\"\n}\n" - headers: - CF-Cache-Status: - - DYNAMIC - CF-RAY: - - 8c85da71fcac1cf3-GRU - Connection: - - keep-alive - Content-Encoding: - - gzip - Content-Type: - - application/json - Date: - - Tue, 24 Sep 2024 21:26:34 GMT - Server: - - cloudflare - Set-Cookie: - - __cf_bm=rb61BZH2ejzD5YPmLaEJqI7km71QqyNJGTVdNxBq6qk-1727213194-1.0.1.1-pJ49onmgX9IugEMuYQMralzD7oj_6W.CHbSu4Su1z3NyjTGYg.rhgJZWng8feFYah._oSnoYlkTjpK1Wd2C9FA; - path=/; expires=Tue, 24-Sep-24 21:56:34 GMT; domain=.api.openai.com; HttpOnly; - Secure; SameSite=None - - _cfuvid=lbRdAddVWV6W3f5Dm9SaOPWDUOxqtZBSPr_fTW26nEA-1727213194587-0.0.1.1-604800000; - path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None - Transfer-Encoding: - - chunked - X-Content-Type-Options: - - nosniff - access-control-expose-headers: - - X-Request-ID - openai-organization: - - crewai-iuxna1 - openai-processing-ms: - - '2244' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - x-ratelimit-limit-requests: - - '30000' - x-ratelimit-limit-tokens: - - '150000000' - x-ratelimit-remaining-requests: - - '29999' - x-ratelimit-remaining-tokens: - - '149999774' - x-ratelimit-reset-requests: - - 2ms - x-ratelimit-reset-tokens: - - 0s - x-request-id: - - req_2e565b5f24c38968e4e923a47ecc6233 - http_version: HTTP/1.1 - status_code: 200 -version: 1 diff --git a/tests/crew_test.py b/tests/crew_test.py index 74a659738..2f347a50e 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -3480,10 +3480,12 @@ def test_crew_guardrail_feedback_in_context(): @pytest.mark.vcr(filter_headers=["authorization"]) def test_before_kickoff_callback(): - from crewai.project import CrewBase, agent, before_kickoff, crew, task + from crewai.project import CrewBase, agent, before_kickoff, task @CrewBase class TestCrewClass: + from crewai.project import crew + agents_config = None tasks_config = None @@ -3510,7 +3512,7 @@ def test_before_kickoff_callback(): task = Task( description="Test task description", expected_output="Test expected output", - agent=self.my_agent(), # Use the agent instance + agent=self.my_agent(), ) return task @@ -3520,28 +3522,30 @@ def test_before_kickoff_callback(): test_crew_instance = TestCrewClass() - crew = test_crew_instance.crew() + test_crew = test_crew_instance.crew() # Verify that the before_kickoff_callbacks are set - assert len(crew.before_kickoff_callbacks) == 1 + assert len(test_crew.before_kickoff_callbacks) == 1 # Prepare inputs inputs = {"initial": True} # Call kickoff - crew.kickoff(inputs=inputs) + test_crew.kickoff(inputs=inputs) # Check that the before_kickoff function was called and modified inputs assert test_crew_instance.inputs_modified - assert inputs.get("modified") == True + assert inputs.get("modified") @pytest.mark.vcr(filter_headers=["authorization"]) def test_before_kickoff_without_inputs(): - from crewai.project import CrewBase, agent, before_kickoff, crew, task + from crewai.project import CrewBase, agent, before_kickoff, task @CrewBase class TestCrewClass: + from crewai.project import crew + agents_config = None tasks_config = None @@ -3579,12 +3583,12 @@ def test_before_kickoff_without_inputs(): # Instantiate the class test_crew_instance = TestCrewClass() # Build the crew - crew = test_crew_instance.crew() + test_crew = test_crew_instance.crew() # Verify that the before_kickoff_callback is registered - assert len(crew.before_kickoff_callbacks) == 1 + assert len(test_crew.before_kickoff_callbacks) == 1 # Call kickoff without passing inputs - output = crew.kickoff() + test_crew.kickoff() # Check that the before_kickoff function was called assert test_crew_instance.inputs_modified