diff --git a/docs/quickstart.mdx b/docs/quickstart.mdx index 49a690093..ef149bfcc 100644 --- a/docs/quickstart.mdx +++ b/docs/quickstart.mdx @@ -330,4 +330,4 @@ This will clear the crew's memory, allowing for a fresh start. ## Deploying Your Project -The easiest way to deploy your crew is through [CrewAI Enterprise](https://www.crewai.com/crewaiplus), where you can deploy your crew in a few clicks. +The easiest way to deploy your crew is through [CrewAI Enterprise](http://app.crewai.com/), where you can deploy your crew in a few clicks. diff --git a/src/crewai/agent.py b/src/crewai/agent.py index ea4231eb3..a3e7fd467 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -8,6 +8,7 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator from crewai.agents import CacheHandler from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor +from crewai.cli.constants import ENV_VARS from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.llm import LLM @@ -140,8 +141,12 @@ class Agent(BaseAgent): # If it's already an LLM instance, keep it as is pass elif self.llm is None: - # If it's None, use environment variables or default - model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-4o-mini") + # Determine the model name from environment variables or use default + model_name = ( + os.environ.get("OPENAI_MODEL_NAME") + or os.environ.get("MODEL") + or "gpt-4o-mini" + ) llm_params = {"model": model_name} api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get( @@ -150,9 +155,39 @@ class Agent(BaseAgent): if api_base: llm_params["base_url"] = api_base - api_key = os.environ.get("OPENAI_API_KEY") - if api_key: - llm_params["api_key"] = api_key + # Iterate over all environment variables to find matching API keys or use defaults + for provider, env_vars in ENV_VARS.items(): + for env_var in env_vars: + # Check if the environment variable is set + if "key_name" in env_var: + env_value = os.environ.get(env_var["key_name"]) + if env_value: + # Map key names containing "API_KEY" to "api_key" + key_name = ( + "api_key" + if "API_KEY" in env_var["key_name"] + else env_var["key_name"] + ) + # Map key names containing "API_BASE" to "api_base" + key_name = ( + "api_base" + if "API_BASE" in env_var["key_name"] + else key_name + ) + # Map key names containing "API_VERSION" to "api_version" + key_name = ( + "api_version" + if "API_VERSION" in env_var["key_name"] + else key_name + ) + llm_params[key_name] = env_value + # Check for default values if the environment variable is not set + elif env_var.get("default", False): + for key, value in env_var.items(): + if key not in ["prompt", "key_name", "default"]: + # Only add default if the key is already set in os.environ + if key in os.environ: + llm_params[key] = value self.llm = LLM(**llm_params) else: diff --git a/src/crewai/cli/add_crew_to_flow.py b/src/crewai/cli/add_crew_to_flow.py index e4901fa89..ef693a22b 100644 --- a/src/crewai/cli/add_crew_to_flow.py +++ b/src/crewai/cli/add_crew_to_flow.py @@ -54,7 +54,7 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None: templates_dir = Path(__file__).parent / "templates" / "crew" config_template_files = ["agents.yaml", "tasks.yaml"] - crew_template_file = f"{folder_name}_crew.py" # Updated file name + crew_template_file = f"{folder_name}.py" # Updated file name for file_name in config_template_files: src_file = templates_dir / "config" / file_name diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index 9a0b36c39..4be08fa2a 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -1,19 +1,168 @@ ENV_VARS = { - 'openai': ['OPENAI_API_KEY'], - 'anthropic': ['ANTHROPIC_API_KEY'], - 'gemini': ['GEMINI_API_KEY'], - 'groq': ['GROQ_API_KEY'], - 'ollama': ['FAKE_KEY'], + "openai": [ + { + "prompt": "Enter your OPENAI API key (press Enter to skip)", + "key_name": "OPENAI_API_KEY", + } + ], + "anthropic": [ + { + "prompt": "Enter your ANTHROPIC API key (press Enter to skip)", + "key_name": "ANTHROPIC_API_KEY", + } + ], + "gemini": [ + { + "prompt": "Enter your GEMINI API key (press Enter to skip)", + "key_name": "GEMINI_API_KEY", + } + ], + "groq": [ + { + "prompt": "Enter your GROQ API key (press Enter to skip)", + "key_name": "GROQ_API_KEY", + } + ], + "watson": [ + { + "prompt": "Enter your WATSONX URL (press Enter to skip)", + "key_name": "WATSONX_URL", + }, + { + "prompt": "Enter your WATSONX API Key (press Enter to skip)", + "key_name": "WATSONX_APIKEY", + }, + { + "prompt": "Enter your WATSONX Project Id (press Enter to skip)", + "key_name": "WATSONX_PROJECT_ID", + }, + ], + "ollama": [ + { + "default": True, + "API_BASE": "http://localhost:11434", + } + ], + "bedrock": [ + { + "prompt": "Enter your AWS Access Key ID (press Enter to skip)", + "key_name": "AWS_ACCESS_KEY_ID", + }, + { + "prompt": "Enter your AWS Secret Access Key (press Enter to skip)", + "key_name": "AWS_SECRET_ACCESS_KEY", + }, + { + "prompt": "Enter your AWS Region Name (press Enter to skip)", + "key_name": "AWS_REGION_NAME", + }, + ], + "azure": [ + { + "prompt": "Enter your Azure deployment name (must start with 'azure/')", + "key_name": "model", + }, + { + "prompt": "Enter your AZURE API key (press Enter to skip)", + "key_name": "AZURE_API_KEY", + }, + { + "prompt": "Enter your AZURE API base URL (press Enter to skip)", + "key_name": "AZURE_API_BASE", + }, + { + "prompt": "Enter your AZURE API version (press Enter to skip)", + "key_name": "AZURE_API_VERSION", + }, + ], + "cerebras": [ + { + "prompt": "Enter your Cerebras model name (must start with 'cerebras/')", + "key_name": "model", + }, + { + "prompt": "Enter your Cerebras API version (press Enter to skip)", + "key_name": "CEREBRAS_API_KEY", + }, + ], } -PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama'] + +PROVIDERS = [ + "openai", + "anthropic", + "gemini", + "groq", + "ollama", + "watson", + "bedrock", + "azure", + "cerebras", +] MODELS = { - 'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini', 'o1-mini', 'o1-preview'], - 'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'], - 'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'], - 'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'], - 'ollama': ['llama3.1', 'mixtral'], + "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini", "o1-mini", "o1-preview"], + "anthropic": [ + "claude-3-5-sonnet-20240620", + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", + "claude-3-haiku-20240307", + ], + "gemini": [ + "gemini/gemini-1.5-flash", + "gemini/gemini-1.5-pro", + "gemini/gemini-gemma-2-9b-it", + "gemini/gemini-gemma-2-27b-it", + ], + "groq": [ + "groq/llama-3.1-8b-instant", + "groq/llama-3.1-70b-versatile", + "groq/llama-3.1-405b-reasoning", + "groq/gemma2-9b-it", + "groq/gemma-7b-it", + ], + "ollama": ["ollama/llama3.1", "ollama/mixtral"], + "watson": [ + "watsonx/google/flan-t5-xxl", + "watsonx/google/flan-ul2", + "watsonx/bigscience/mt0-xxl", + "watsonx/eleutherai/gpt-neox-20b", + "watsonx/ibm/mpt-7b-instruct2", + "watsonx/bigcode/starcoder", + "watsonx/meta-llama/llama-2-70b-chat", + "watsonx/meta-llama/llama-2-13b-chat", + "watsonx/ibm/granite-13b-instruct-v1", + "watsonx/ibm/granite-13b-chat-v1", + "watsonx/google/flan-t5-xl", + "watsonx/ibm/granite-13b-chat-v2", + "watsonx/ibm/granite-13b-instruct-v2", + "watsonx/elyza/elyza-japanese-llama-2-7b-instruct", + "watsonx/ibm-mistralai/mixtral-8x7b-instruct-v01-q", + ], + "bedrock": [ + "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", + "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", + "bedrock/anthropic.claude-3-haiku-20240307-v1:0", + "bedrock/anthropic.claude-3-opus-20240229-v1:0", + "bedrock/anthropic.claude-v2:1", + "bedrock/anthropic.claude-v2", + "bedrock/anthropic.claude-instant-v1", + "bedrock/meta.llama3-1-405b-instruct-v1:0", + "bedrock/meta.llama3-1-70b-instruct-v1:0", + "bedrock/meta.llama3-1-8b-instruct-v1:0", + "bedrock/meta.llama3-70b-instruct-v1:0", + "bedrock/meta.llama3-8b-instruct-v1:0", + "bedrock/amazon.titan-text-lite-v1", + "bedrock/amazon.titan-text-express-v1", + "bedrock/cohere.command-text-v14", + "bedrock/ai21.j2-mid-v1", + "bedrock/ai21.j2-ultra-v1", + "bedrock/ai21.jamba-instruct-v1:0", + "bedrock/meta.llama2-13b-chat-v1", + "bedrock/meta.llama2-70b-chat-v1", + "bedrock/mistral.mistral-7b-instruct-v0:2", + "bedrock/mistral.mixtral-8x7b-instruct-v0:1", + ], } -JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" \ No newline at end of file +JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index 5767b82a1..06440d74e 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -1,11 +1,11 @@ +import shutil import sys from pathlib import Path import click -from crewai.cli.constants import ENV_VARS +from crewai.cli.constants import ENV_VARS, MODELS from crewai.cli.provider import ( - PROVIDERS, get_provider_data, select_model, select_provider, @@ -29,20 +29,20 @@ def create_folder_structure(name, parent_folder=None): click.secho("Operation cancelled.", fg="yellow") sys.exit(0) click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True) - else: - click.secho( - f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...", - fg="green", - bold=True, - ) + shutil.rmtree(folder_path) # Delete the existing folder and its contents - if not folder_path.exists(): - folder_path.mkdir(parents=True) - (folder_path / "tests").mkdir(exist_ok=True) - if not parent_folder: - (folder_path / "src" / folder_name).mkdir(parents=True) - (folder_path / "src" / folder_name / "tools").mkdir(parents=True) - (folder_path / "src" / folder_name / "config").mkdir(parents=True) + click.secho( + f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...", + fg="green", + bold=True, + ) + + folder_path.mkdir(parents=True) + (folder_path / "tests").mkdir(exist_ok=True) + if not parent_folder: + (folder_path / "src" / folder_name).mkdir(parents=True) + (folder_path / "src" / folder_name / "tools").mkdir(parents=True) + (folder_path / "src" / folder_name / "config").mkdir(parents=True) return folder_path, folder_name, class_name @@ -92,7 +92,10 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): existing_provider = None for provider, env_keys in ENV_VARS.items(): - if any(key in env_vars for key in env_keys): + if any( + "key_name" in details and details["key_name"] in env_vars + for details in env_keys + ): existing_provider = provider break @@ -118,47 +121,48 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): "No provider selected. Please try again or press 'q' to exit.", fg="red" ) - while True: - selected_model = select_model(selected_provider, provider_models) - if selected_model is None: # User typed 'q' - click.secho("Exiting...", fg="yellow") - sys.exit(0) - if selected_model: # Valid selection - break - click.secho( - "No model selected. Please try again or press 'q' to exit.", fg="red" - ) + # Check if the selected provider has predefined models + if selected_provider in MODELS and MODELS[selected_provider]: + while True: + selected_model = select_model(selected_provider, provider_models) + if selected_model is None: # User typed 'q' + click.secho("Exiting...", fg="yellow") + sys.exit(0) + if selected_model: # Valid selection + break + click.secho( + "No model selected. Please try again or press 'q' to exit.", + fg="red", + ) + env_vars["MODEL"] = selected_model - if selected_provider in PROVIDERS: - api_key_var = ENV_VARS[selected_provider][0] - else: - api_key_var = click.prompt( - f"Enter the environment variable name for your {selected_provider.capitalize()} API key", - type=str, - default="", - ) + # Check if the selected provider requires API keys + if selected_provider in ENV_VARS: + provider_env_vars = ENV_VARS[selected_provider] + for details in provider_env_vars: + if details.get("default", False): + # Automatically add default key-value pairs + for key, value in details.items(): + if key not in ["prompt", "key_name", "default"]: + env_vars[key] = value + elif "key_name" in details: + # Prompt for non-default key-value pairs + prompt = details["prompt"] + key_name = details["key_name"] + api_key_value = click.prompt(prompt, default="", show_default=False) - api_key_value = "" - click.echo( - f"Enter your {selected_provider.capitalize()} API key (press Enter to skip): ", - nl=False, - ) - try: - api_key_value = input() - except (KeyboardInterrupt, EOFError): - api_key_value = "" + if api_key_value.strip(): + env_vars[key_name] = api_key_value - if api_key_value.strip(): - env_vars = {api_key_var: api_key_value} + if env_vars: write_env_file(folder_path, env_vars) - click.secho("API key saved to .env file", fg="green") + click.secho("API keys and model saved to .env file", fg="green") else: click.secho( - "No API key provided. Skipping .env file creation.", fg="yellow" + "No API keys provided. Skipping .env file creation.", fg="yellow" ) - env_vars["MODEL"] = selected_model - click.secho(f"Selected model: {selected_model}", fg="green") + click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green") package_dir = Path(__file__).parent templates_dir = package_dir / "templates" / "crew" diff --git a/src/crewai/cli/templates/crew/crew.py b/src/crewai/cli/templates/crew/crew.py index f950d13d4..c47315415 100644 --- a/src/crewai/cli/templates/crew/crew.py +++ b/src/crewai/cli/templates/crew/crew.py @@ -8,9 +8,12 @@ from crewai.project import CrewBase, agent, crew, task # from crewai_tools import SerperDevTool @CrewBase -class {{crew_name}}Crew(): +class {{crew_name}}(): """{{crew_name}} crew""" + agents_config = 'config/agents.yaml' + tasks_config = 'config/tasks.yaml' + @agent def researcher(self) -> Agent: return Agent( @@ -48,4 +51,4 @@ class {{crew_name}}Crew(): process=Process.sequential, verbose=True, # process=Process.hierarchical, # In case you wanna use that instead https://docs.crewai.com/how-to/Hierarchical/ - ) \ No newline at end of file + ) diff --git a/src/crewai/cli/templates/crew/main.py b/src/crewai/cli/templates/crew/main.py index 88edfcbff..d5224edcf 100644 --- a/src/crewai/cli/templates/crew/main.py +++ b/src/crewai/cli/templates/crew/main.py @@ -1,6 +1,10 @@ #!/usr/bin/env python import sys -from {{folder_name}}.crew import {{crew_name}}Crew +import warnings + +from {{folder_name}}.crew import {{crew_name}} + +warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd") # This main file is intended to be a way for you to run your # crew locally, so refrain from adding unnecessary logic into this file. @@ -14,7 +18,7 @@ def run(): inputs = { 'topic': 'AI LLMs' } - {{crew_name}}Crew().crew().kickoff(inputs=inputs) + {{crew_name}}().crew().kickoff(inputs=inputs) def train(): @@ -25,7 +29,7 @@ def train(): "topic": "AI LLMs" } try: - {{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs) + {{crew_name}}().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs) except Exception as e: raise Exception(f"An error occurred while training the crew: {e}") @@ -35,7 +39,7 @@ def replay(): Replay the crew execution from a specific task. """ try: - {{crew_name}}Crew().crew().replay(task_id=sys.argv[1]) + {{crew_name}}().crew().replay(task_id=sys.argv[1]) except Exception as e: raise Exception(f"An error occurred while replaying the crew: {e}") @@ -48,7 +52,7 @@ def test(): "topic": "AI LLMs" } try: - {{crew_name}}Crew().crew().test(n_iterations=int(sys.argv[1]), openai_model_name=sys.argv[2], inputs=inputs) + {{crew_name}}().crew().test(n_iterations=int(sys.argv[1]), openai_model_name=sys.argv[2], inputs=inputs) except Exception as e: raise Exception(f"An error occurred while replaying the crew: {e}") diff --git a/src/crewai/crew.py b/src/crewai/crew.py index e65024ed6..7bcaa82ad 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -445,13 +445,14 @@ class Crew(BaseModel): training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load() for agent in train_crew.agents: - result = TaskEvaluator(agent).evaluate_training_data( - training_data=training_data, agent_id=str(agent.id) - ) + if training_data.get(str(agent.id)): + result = TaskEvaluator(agent).evaluate_training_data( + training_data=training_data, agent_id=str(agent.id) + ) - CrewTrainingHandler(filename).save_trained_data( - agent_id=str(agent.role), trained_data=result.model_dump() - ) + CrewTrainingHandler(filename).save_trained_data( + agent_id=str(agent.role), trained_data=result.model_dump() + ) def kickoff( self, diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 9b6463d65..fa0902594 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -131,7 +131,6 @@ class FlowMeta(type): condition_type = getattr(attr_value, "__condition_type__", "OR") listeners[attr_name] = (condition_type, methods) - # TODO: should we add a check for __condition_type__ 'AND'? elif hasattr(attr_value, "__is_router__"): routers[attr_value.__router_for__] = attr_name possible_returns = get_possible_return_constants(attr_value) @@ -171,8 +170,7 @@ class Flow(Generic[T], metaclass=FlowMeta): def __init__(self) -> None: self._methods: Dict[str, Callable] = {} self._state: T = self._create_initial_state() - self._executed_methods: Set[str] = set() - self._scheduled_tasks: Set[str] = set() + self._method_execution_counts: Dict[str, int] = {} self._pending_and_listeners: Dict[str, Set[str]] = {} self._method_outputs: List[Any] = [] # List to store all method outputs @@ -309,7 +307,10 @@ class Flow(Generic[T], metaclass=FlowMeta): ) self._method_outputs.append(result) # Store the output - self._executed_methods.add(method_name) + # Track method execution counts + self._method_execution_counts[method_name] = ( + self._method_execution_counts.get(method_name, 0) + 1 + ) return result @@ -319,35 +320,34 @@ class Flow(Generic[T], metaclass=FlowMeta): if trigger_method in self._routers: router_method = self._methods[self._routers[trigger_method]] path = await self._execute_method( - trigger_method, router_method - ) # TODO: Change or not? - # Use the path as the new trigger method + self._routers[trigger_method], router_method + ) trigger_method = path for listener_name, (condition_type, methods) in self._listeners.items(): if condition_type == "OR": if trigger_method in methods: - if ( - listener_name not in self._executed_methods - and listener_name not in self._scheduled_tasks - ): - self._scheduled_tasks.add(listener_name) - listener_tasks.append( - self._execute_single_listener(listener_name, result) - ) + # Schedule the listener without preventing re-execution + listener_tasks.append( + self._execute_single_listener(listener_name, result) + ) elif condition_type == "AND": - if all(method in self._executed_methods for method in methods): - if ( - listener_name not in self._executed_methods - and listener_name not in self._scheduled_tasks - ): - self._scheduled_tasks.add(listener_name) - listener_tasks.append( - self._execute_single_listener(listener_name, result) - ) + # Initialize pending methods for this listener if not already done + if listener_name not in self._pending_and_listeners: + self._pending_and_listeners[listener_name] = set(methods) + # Remove the trigger method from pending methods + self._pending_and_listeners[listener_name].discard(trigger_method) + if not self._pending_and_listeners[listener_name]: + # All required methods have been executed + listener_tasks.append( + self._execute_single_listener(listener_name, result) + ) + # Reset pending methods for this listener + self._pending_and_listeners.pop(listener_name, None) # Run all listener tasks concurrently and wait for them to complete - await asyncio.gather(*listener_tasks) + if listener_tasks: + await asyncio.gather(*listener_tasks) async def _execute_single_listener(self, listener_name: str, result: Any) -> None: try: @@ -367,9 +367,6 @@ class Flow(Generic[T], metaclass=FlowMeta): # If listener does not expect parameters, call without arguments listener_result = await self._execute_method(listener_name, method) - # Remove from scheduled tasks after execution - self._scheduled_tasks.discard(listener_name) - # Execute listeners of this listener await self._execute_listeners(listener_name, listener_result) except Exception as e: diff --git a/tests/flow_test.py b/tests/flow_test.py new file mode 100644 index 000000000..ffd82367c --- /dev/null +++ b/tests/flow_test.py @@ -0,0 +1,264 @@ +"""Test Flow creation and execution basic functionality.""" + +import asyncio + +import pytest +from crewai.flow.flow import Flow, and_, listen, or_, router, start + + +def test_simple_sequential_flow(): + """Test a simple flow with two steps called sequentially.""" + execution_order = [] + + class SimpleFlow(Flow): + @start() + def step_1(self): + execution_order.append("step_1") + + @listen(step_1) + def step_2(self): + execution_order.append("step_2") + + flow = SimpleFlow() + flow.kickoff() + + assert execution_order == ["step_1", "step_2"] + + +def test_flow_with_multiple_starts(): + """Test a flow with multiple start methods.""" + execution_order = [] + + class MultiStartFlow(Flow): + @start() + def step_a(self): + execution_order.append("step_a") + + @start() + def step_b(self): + execution_order.append("step_b") + + @listen(step_a) + def step_c(self): + execution_order.append("step_c") + + @listen(step_b) + def step_d(self): + execution_order.append("step_d") + + flow = MultiStartFlow() + flow.kickoff() + + assert "step_a" in execution_order + assert "step_b" in execution_order + assert "step_c" in execution_order + assert "step_d" in execution_order + assert execution_order.index("step_c") > execution_order.index("step_a") + assert execution_order.index("step_d") > execution_order.index("step_b") + + +def test_cyclic_flow(): + """Test a cyclic flow that runs a finite number of iterations.""" + execution_order = [] + + class CyclicFlow(Flow): + iteration = 0 + max_iterations = 3 + + @start("loop") + def step_1(self): + if self.iteration >= self.max_iterations: + return # Do not proceed further + execution_order.append(f"step_1_{self.iteration}") + + @listen(step_1) + def step_2(self): + execution_order.append(f"step_2_{self.iteration}") + + @router(step_2) + def step_3(self): + execution_order.append(f"step_3_{self.iteration}") + self.iteration += 1 + if self.iteration < self.max_iterations: + return "loop" + + return "exit" + + flow = CyclicFlow() + flow.kickoff() + + expected_order = [] + for i in range(flow.max_iterations): + expected_order.extend([f"step_1_{i}", f"step_2_{i}", f"step_3_{i}"]) + + assert execution_order == expected_order + + +def test_flow_with_and_condition(): + """Test a flow where a step waits for multiple other steps to complete.""" + execution_order = [] + + class AndConditionFlow(Flow): + @start() + def step_1(self): + execution_order.append("step_1") + + @start() + def step_2(self): + execution_order.append("step_2") + + @listen(and_(step_1, step_2)) + def step_3(self): + execution_order.append("step_3") + + flow = AndConditionFlow() + flow.kickoff() + + assert "step_1" in execution_order + assert "step_2" in execution_order + assert execution_order[-1] == "step_3" + assert execution_order.index("step_3") > execution_order.index("step_1") + assert execution_order.index("step_3") > execution_order.index("step_2") + + +def test_flow_with_or_condition(): + """Test a flow where a step is triggered when any of multiple steps complete.""" + execution_order = [] + + class OrConditionFlow(Flow): + @start() + def step_a(self): + execution_order.append("step_a") + + @start() + def step_b(self): + execution_order.append("step_b") + + @listen(or_(step_a, step_b)) + def step_c(self): + execution_order.append("step_c") + + flow = OrConditionFlow() + flow.kickoff() + + assert "step_a" in execution_order or "step_b" in execution_order + assert "step_c" in execution_order + assert execution_order.index("step_c") > min( + execution_order.index("step_a"), execution_order.index("step_b") + ) + + +def test_flow_with_router(): + """Test a flow that uses a router method to determine the next step.""" + execution_order = [] + + class RouterFlow(Flow): + @start() + def start_method(self): + execution_order.append("start_method") + + @router(start_method) + def router(self): + execution_order.append("router") + # Ensure the condition is set to True to follow the "step_if_true" path + condition = True + return "step_if_true" if condition else "step_if_false" + + @listen("step_if_true") + def truthy(self): + execution_order.append("step_if_true") + + @listen("step_if_false") + def falsy(self): + execution_order.append("step_if_false") + + flow = RouterFlow() + flow.kickoff() + + assert execution_order == ["start_method", "router", "step_if_true"] + + +def test_async_flow(): + """Test an asynchronous flow.""" + execution_order = [] + + class AsyncFlow(Flow): + @start() + async def step_1(self): + execution_order.append("step_1") + await asyncio.sleep(0.1) + + @listen(step_1) + async def step_2(self): + execution_order.append("step_2") + await asyncio.sleep(0.1) + + flow = AsyncFlow() + asyncio.run(flow.kickoff_async()) + + assert execution_order == ["step_1", "step_2"] + + +def test_flow_with_exceptions(): + """Test flow behavior when exceptions occur in steps.""" + execution_order = [] + + class ExceptionFlow(Flow): + @start() + def step_1(self): + execution_order.append("step_1") + raise ValueError("An error occurred in step_1") + + @listen(step_1) + def step_2(self): + execution_order.append("step_2") + + flow = ExceptionFlow() + + with pytest.raises(ValueError): + flow.kickoff() + + # Ensure step_2 did not execute + assert execution_order == ["step_1"] + + +def test_flow_restart(): + """Test restarting a flow after it has completed.""" + execution_order = [] + + class RestartableFlow(Flow): + @start() + def step_1(self): + execution_order.append("step_1") + + @listen(step_1) + def step_2(self): + execution_order.append("step_2") + + flow = RestartableFlow() + flow.kickoff() + flow.kickoff() # Restart the flow + + assert execution_order == ["step_1", "step_2", "step_1", "step_2"] + + +def test_flow_with_custom_state(): + """Test a flow that maintains and modifies internal state.""" + + class StateFlow(Flow): + def __init__(self): + super().__init__() + self.counter = 0 + + @start() + def step_1(self): + self.counter += 1 + + @listen(step_1) + def step_2(self): + self.counter *= 2 + assert self.counter == 2 + + flow = StateFlow() + flow.kickoff() + assert flow.counter == 2