From 8fb08a0b25cefa8a1affce79c50b12dd1a8998a4 Mon Sep 17 00:00:00 2001 From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com> Date: Tue, 1 Oct 2024 14:20:26 -0400 Subject: [PATCH] Flow visualizer (#1377) * Almost working! * It fully works but not clean enought * Working but not clean engouth * Everything is workign * WIP. Working on adding and & or to flows. In the middle of setting up template for flow as well * template working * Everything is working * More changes and todos * Add more support for @start * Router working now * minor tweak to * minor tweak to conditions and event handling * Update logs * Too trigger happy with cleanup * Added in Thiago fix * Flow passing results again * Working on docs. * made more progress updates on docs * Finished talking about controlling flows * add flow output * fixed flow output section * add crews to flows section is looking good now * more flow doc changes * Update docs and add more examples * drop visualizer * save visualizer * pyvis is beginning to work * pyvis working * it is working * regular methods and triggers working. Need to work on router next. * properly identifying router and router children nodes. Need to fix color * children router working. Need to support loops * curving cycles but need to add curve conditionals * everythin is showing up properly need to fix curves * all working. needs to be cleaned up * adjust padding * drop lib * clean up prior to PR * incorporate joao feedback * final tweaks for joao * Refactor to make crews easier to understand * update CLI and templates * Fix crewai version in flows * Fix merge conflict --- poetry.lock | 83 +--- src/crewai/cli/cli.py | 24 +- src/crewai/cli/plot_flow.py | 23 + src/crewai/cli/run_flow.py | 23 + src/crewai/cli/templates/flow/main.py | 19 +- src/crewai/cli/templates/flow/pyproject.toml | 5 +- src/crewai/flow/config.py | 46 ++ src/crewai/flow/flow.py | 8 +- src/crewai/flow/flow_visualizer.py | 473 ++----------------- src/crewai/flow/html_template_handler.py | 66 +++ src/crewai/flow/legend_generator.py | 46 ++ src/crewai/flow/utils.py | 143 ++++++ src/crewai/flow/visualization_utils.py | 132 ++++++ 13 files changed, 592 insertions(+), 499 deletions(-) create mode 100644 src/crewai/cli/plot_flow.py create mode 100644 src/crewai/cli/run_flow.py create mode 100644 src/crewai/flow/config.py create mode 100644 src/crewai/flow/html_template_handler.py create mode 100644 src/crewai/flow/legend_generator.py create mode 100644 src/crewai/flow/utils.py create mode 100644 src/crewai/flow/visualization_utils.py diff --git a/poetry.lock b/poetry.lock index a4d1b8ca2..354b48e9f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1936,13 +1936,13 @@ test = ["objgraph", "psutil"] [[package]] name = "griffe" -version = "1.3.1" +version = "1.3.2" description = "Signatures for entire Python programs. Extract the structure, the frame, the skeleton of your project, to generate API documentation or find breaking changes in your API." optional = false python-versions = ">=3.8" files = [ - {file = "griffe-1.3.1-py3-none-any.whl", hash = "sha256:940aeb630bc3054b4369567f150b6365be6f11eef46b0ed8623aea96e6d17b19"}, - {file = "griffe-1.3.1.tar.gz", hash = "sha256:3f86a716b631a4c0f96a43cb75d05d3c85975003c20540426c0eba3b0581c56a"}, + {file = "griffe-1.3.2-py3-none-any.whl", hash = "sha256:2e34b5e46507d615915c8e6288bb1a2234bd35dee44d01e40a2bc2f25bd4d10c"}, + {file = "griffe-1.3.2.tar.gz", hash = "sha256:1ec50335aa507ed2445f2dd45a15c9fa3a45f52c9527e880571dfc61912fd60c"}, ] [package.dependencies] @@ -2148,13 +2148,13 @@ files = [ [[package]] name = "httpcore" -version = "1.0.5" +version = "1.0.6" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" files = [ - {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, - {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, + {file = "httpcore-1.0.6-py3-none-any.whl", hash = "sha256:27b59625743b85577a8c0e10e55b50b5368a4f2cfe8cc7bcfa9cf00829c2682f"}, + {file = "httpcore-1.0.6.tar.gz", hash = "sha256:73f6dbd6eb8c21bbf7ef8efad555481853f5f6acdeaff1edb0694289269ee17f"}, ] [package.dependencies] @@ -2165,7 +2165,7 @@ h11 = ">=0.13,<0.15" asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] -trio = ["trio (>=0.22.0,<0.26.0)"] +trio = ["trio (>=0.22.0,<1.0)"] [[package]] name = "httptools" @@ -3891,13 +3891,13 @@ sympy = "*" [[package]] name = "openai" -version = "1.50.2" +version = "1.51.0" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.50.2-py3-none-any.whl", hash = "sha256:822dd2051baa3393d0d5406990611975dd6f533020dc9375a34d4fe67e8b75f7"}, - {file = "openai-1.50.2.tar.gz", hash = "sha256:3987ae027152fc8bea745d60b02c8f4c4a76e1b5c70e73565fa556db6f78c9e6"}, + {file = "openai-1.51.0-py3-none-any.whl", hash = "sha256:d9affafb7e51e5a27dce78589d4964ce4d6f6d560307265933a94b2e3f3c5d2c"}, + {file = "openai-1.51.0.tar.gz", hash = "sha256:8dc4f9d75ccdd5466fc8c99a952186eddceb9fd6ba694044773f3736a847149d"}, ] [package.dependencies] @@ -5058,13 +5058,13 @@ torch = ["torch"] [[package]] name = "pymdown-extensions" -version = "10.11.1" +version = "10.11.2" description = "Extension pack for Python Markdown." optional = false python-versions = ">=3.8" files = [ - {file = "pymdown_extensions-10.11.1-py3-none-any.whl", hash = "sha256:a2b28f5786e041f19cb5bb30a1c2c853668a7099da8e3dd822a5ad05f2e855e3"}, - {file = "pymdown_extensions-10.11.1.tar.gz", hash = "sha256:a8836e955851542fa2625d04d59fdf97125ca001377478ed5618e04f9183a59a"}, + {file = "pymdown_extensions-10.11.2-py3-none-any.whl", hash = "sha256:41cdde0a77290e480cf53892f5c5e50921a7ee3e5cd60ba91bf19837b33badcf"}, + {file = "pymdown_extensions-10.11.2.tar.gz", hash = "sha256:bc8847ecc9e784a098efd35e20cba772bc5a1b529dfcef9dc1972db9021a1049"}, ] [package.dependencies] @@ -5747,18 +5747,19 @@ py = ">=1.4.26,<2.0.0" [[package]] name = "rich" -version = "13.8.1" +version = "13.9.1" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false -python-versions = ">=3.7.0" +python-versions = ">=3.8.0" files = [ - {file = "rich-13.8.1-py3-none-any.whl", hash = "sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06"}, - {file = "rich-13.8.1.tar.gz", hash = "sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a"}, + {file = "rich-13.9.1-py3-none-any.whl", hash = "sha256:b340e739f30aa58921dc477b8adaa9ecdb7cecc217be01d93730ee1bc8aa83be"}, + {file = "rich-13.9.1.tar.gz", hash = "sha256:097cffdf85db1babe30cc7deba5ab3a29e1b9885047dab24c57e9a7f8a9c1466"}, ] [package.dependencies] markdown-it-py = ">=2.2.0" pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] @@ -6102,54 +6103,6 @@ description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.35-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:67219632be22f14750f0d1c70e62f204ba69d28f62fd6432ba05ab295853de9b"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4668bd8faf7e5b71c0319407b608f278f279668f358857dbfd10ef1954ac9f90"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb8bea573863762bbf45d1e13f87c2d2fd32cee2dbd50d050f83f87429c9e1ea"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f552023710d4b93d8fb29a91fadf97de89c5926c6bd758897875435f2a939f33"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:016b2e665f778f13d3c438651dd4de244214b527a275e0acf1d44c05bc6026a9"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7befc148de64b6060937231cbff8d01ccf0bfd75aa26383ffdf8d82b12ec04ff"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-win32.whl", hash = "sha256:22b83aed390e3099584b839b93f80a0f4a95ee7f48270c97c90acd40ee646f0b"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-win_amd64.whl", hash = "sha256:a29762cd3d116585278ffb2e5b8cc311fb095ea278b96feef28d0b423154858e"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e21f66748ab725ade40fa7af8ec8b5019c68ab00b929f6643e1b1af461eddb60"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8a6219108a15fc6d24de499d0d515c7235c617b2540d97116b663dade1a54d62"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:042622a5306c23b972192283f4e22372da3b8ddf5f7aac1cc5d9c9b222ab3ff6"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:627dee0c280eea91aed87b20a1f849e9ae2fe719d52cbf847c0e0ea34464b3f7"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4fdcd72a789c1c31ed242fd8c1bcd9ea186a98ee8e5408a50e610edfef980d71"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:89b64cd8898a3a6f642db4eb7b26d1b28a497d4022eccd7717ca066823e9fb01"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-win32.whl", hash = "sha256:6a93c5a0dfe8d34951e8a6f499a9479ffb9258123551fa007fc708ae2ac2bc5e"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-win_amd64.whl", hash = "sha256:c68fe3fcde03920c46697585620135b4ecfdfc1ed23e75cc2c2ae9f8502c10b8"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eb60b026d8ad0c97917cb81d3662d0b39b8ff1335e3fabb24984c6acd0c900a2"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6921ee01caf375363be5e9ae70d08ce7ca9d7e0e8983183080211a062d299468"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8cdf1a0dbe5ced887a9b127da4ffd7354e9c1a3b9bb330dce84df6b70ccb3a8d"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93a71c8601e823236ac0e5d087e4f397874a421017b3318fd92c0b14acf2b6db"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e04b622bb8a88f10e439084486f2f6349bf4d50605ac3e445869c7ea5cf0fa8c"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1b56961e2d31389aaadf4906d453859f35302b4eb818d34a26fab72596076bb8"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-win32.whl", hash = "sha256:0f9f3f9a3763b9c4deb8c5d09c4cc52ffe49f9876af41cc1b2ad0138878453cf"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-win_amd64.whl", hash = "sha256:25b0f63e7fcc2a6290cb5f7f5b4fc4047843504983a28856ce9b35d8f7de03cc"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f021d334f2ca692523aaf7bbf7592ceff70c8594fad853416a81d66b35e3abf9"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05c3f58cf91683102f2f0265c0db3bd3892e9eedabe059720492dbaa4f922da1"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:032d979ce77a6c2432653322ba4cbeabf5a6837f704d16fa38b5a05d8e21fa00"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:2e795c2f7d7249b75bb5f479b432a51b59041580d20599d4e112b5f2046437a3"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:cc32b2990fc34380ec2f6195f33a76b6cdaa9eecf09f0c9404b74fc120aef36f"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-win32.whl", hash = "sha256:9509c4123491d0e63fb5e16199e09f8e262066e58903e84615c301dde8fa2e87"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-win_amd64.whl", hash = "sha256:3655af10ebcc0f1e4e06c5900bb33e080d6a1fa4228f502121f28a3b1753cde5"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4c31943b61ed8fdd63dfd12ccc919f2bf95eefca133767db6fbbd15da62078ec"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a62dd5d7cc8626a3634208df458c5fe4f21200d96a74d122c83bc2015b333bc1"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0630774b0977804fba4b6bbea6852ab56c14965a2b0c7fc7282c5f7d90a1ae72"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d625eddf7efeba2abfd9c014a22c0f6b3796e0ffb48f5d5ab106568ef01ff5a"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ada603db10bb865bbe591939de854faf2c60f43c9b763e90f653224138f910d9"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c41411e192f8d3ea39ea70e0fae48762cd11a2244e03751a98bd3c0ca9a4e936"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-win32.whl", hash = "sha256:d299797d75cd747e7797b1b41817111406b8b10a4f88b6e8fe5b5e59598b43b0"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-win_amd64.whl", hash = "sha256:0375a141e1c0878103eb3d719eb6d5aa444b490c96f3fedab8471c7f6ffe70ee"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ccae5de2a0140d8be6838c331604f91d6fafd0735dbdcee1ac78fc8fbaba76b4"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2a275a806f73e849e1c309ac11108ea1a14cd7058577aba962cd7190e27c9e3c"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:732e026240cdd1c1b2e3ac515c7a23820430ed94292ce33806a95869c46bd139"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:890da8cd1941fa3dab28c5bac3b9da8502e7e366f895b3b8e500896f12f94d11"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c0d8326269dbf944b9201911b0d9f3dc524d64779a07518199a58384c3d37a44"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b76d63495b0508ab9fc23f8152bac63205d2a704cd009a2b0722f4c8e0cba8e0"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-win32.whl", hash = "sha256:69683e02e8a9de37f17985905a5eca18ad651bf592314b4d3d799029797d0eb3"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-win_amd64.whl", hash = "sha256:aee110e4ef3c528f3abbc3c2018c121e708938adeeff9006428dd7c8555e9b3f"}, - {file = "SQLAlchemy-2.0.35-py3-none-any.whl", hash = "sha256:2ab3f0336c0387662ce6221ad30ab3a5e6499aab01b9790879b6578fd9b8faa1"}, {file = "sqlalchemy-2.0.35.tar.gz", hash = "sha256:e11d7ea4d24f0a262bccf9a7cd6284c976c5369dac21db237cff59586045ab9f"}, ] diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 22e1aad3c..de6160ba6 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -12,12 +12,14 @@ from crewai.memory.storage.kickoff_task_outputs_storage import ( from .authentication.main import AuthenticationCommand from .deploy.main import DeployCommand -from .tools.main import ToolCommand from .evaluate_crew import evaluate_crew from .install_crew import install_crew +from .plot_flow import plot_flow from .replay_from_task import replay_task_command from .reset_memories_command import reset_memories_command from .run_crew import run_crew +from .run_flow import run_flow +from .tools.main import ToolCommand from .train_crew import train_crew @@ -273,5 +275,25 @@ def tool_publish(is_public: bool): tool_cmd.publish(is_public) +@crewai.group() +def flow(): + """Flow related commands.""" + pass + + +@flow.command(name="run") +def flow_run(): + """Run the Flow.""" + click.echo("Running the Flow") + run_flow() + + +@flow.command(name="plot") +def flow_plot(): + """Plot the Flow.""" + click.echo("Plotting the Flow") + plot_flow() + + if __name__ == "__main__": crewai() diff --git a/src/crewai/cli/plot_flow.py b/src/crewai/cli/plot_flow.py new file mode 100644 index 000000000..bb7b2052f --- /dev/null +++ b/src/crewai/cli/plot_flow.py @@ -0,0 +1,23 @@ +import subprocess + +import click + + +def plot_flow() -> None: + """ + Plot the flow by running a command in the Poetry environment. + """ + command = ["poetry", "run", "plot_flow"] + + try: + result = subprocess.run(command, capture_output=False, text=True, check=True) + + if result.stderr: + click.echo(result.stderr, err=True) + + except subprocess.CalledProcessError as e: + click.echo(f"An error occurred while plotting the flow: {e}", err=True) + click.echo(e.output, err=True) + + except Exception as e: + click.echo(f"An unexpected error occurred: {e}", err=True) diff --git a/src/crewai/cli/run_flow.py b/src/crewai/cli/run_flow.py new file mode 100644 index 000000000..3a9e72817 --- /dev/null +++ b/src/crewai/cli/run_flow.py @@ -0,0 +1,23 @@ +import subprocess + +import click + + +def run_flow() -> None: + """ + Run the flow by running a command in the Poetry environment. + """ + command = ["poetry", "run", "run_flow"] + + try: + result = subprocess.run(command, capture_output=False, text=True, check=True) + + if result.stderr: + click.echo(result.stderr, err=True) + + except subprocess.CalledProcessError as e: + click.echo(f"An error occurred while running the flow: {e}", err=True) + click.echo(e.output, err=True) + + except Exception as e: + click.echo(f"An unexpected error occurred: {e}", err=True) diff --git a/src/crewai/cli/templates/flow/main.py b/src/crewai/cli/templates/flow/main.py index bda89065d..38d2d8736 100644 --- a/src/crewai/cli/templates/flow/main.py +++ b/src/crewai/cli/templates/flow/main.py @@ -22,8 +22,7 @@ class PoemFlow(Flow[PoemState]): def generate_poem(self): print("Generating poem") print(f"State before poem: {self.state}") - poem_crew = PoemCrew().crew() - result = poem_crew.kickoff(inputs={"sentence_count": self.state.sentence_count}) + result = PoemCrew().crew().kickoff(inputs={"sentence_count": self.state.sentence_count}) print("Poem generated", result.raw) self.state.poem = result.raw @@ -38,16 +37,28 @@ class PoemFlow(Flow[PoemState]): f.write(self.state.poem) print(f"State after save_poem: {self.state}") -async def run(): +async def run_flow(): """ Run the flow. """ poem_flow = PoemFlow() await poem_flow.kickoff() +async def plot_flow(): + """ + Plot the flow. + """ + poem_flow = PoemFlow() + poem_flow.plot() + def main(): - asyncio.run(run()) + asyncio.run(run_flow()) + + +def plot(): + asyncio.run(plot_flow()) + if __name__ == "__main__": diff --git a/src/crewai/cli/templates/flow/pyproject.toml b/src/crewai/cli/templates/flow/pyproject.toml index 0e7083c71..3753cf941 100644 --- a/src/crewai/cli/templates/flow/pyproject.toml +++ b/src/crewai/cli/templates/flow/pyproject.toml @@ -6,12 +6,13 @@ authors = ["Your Name "] [tool.poetry.dependencies] python = ">=3.10,<=3.13" -crewai = { extras = ["tools"], version = ">=0.55.2,<1.0.0" } +crewai = { extras = ["tools"], version = ">=0.66.0,<1.0.0" } asyncio = "*" [tool.poetry.scripts] {{folder_name}} = "{{folder_name}}.main:main" -run_crew = "{{folder_name}}.main:main" +run_flow = "{{folder_name}}.main:main" +plot_flow = "{{folder_name}}.main:plot" [build-system] requires = ["poetry-core"] diff --git a/src/crewai/flow/config.py b/src/crewai/flow/config.py new file mode 100644 index 000000000..ddaddc7a8 --- /dev/null +++ b/src/crewai/flow/config.py @@ -0,0 +1,46 @@ +DARK_GRAY = "#333333" +CREWAI_ORANGE = "#FF5A50" +GRAY = "#666666" +WHITE = "#FFFFFF" + +COLORS = { + "bg": WHITE, + "start": CREWAI_ORANGE, + "method": DARK_GRAY, + "router": DARK_GRAY, + "router_border": CREWAI_ORANGE, + "edge": GRAY, + "router_edge": CREWAI_ORANGE, + "text": WHITE, +} + +NODE_STYLES = { + "start": { + "color": COLORS["start"], + "shape": "box", + "font": {"color": COLORS["text"]}, + "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, + }, + "method": { + "color": COLORS["method"], + "shape": "box", + "font": {"color": COLORS["text"]}, + "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, + }, + "router": { + "color": { + "background": COLORS["router"], + "border": COLORS["router_border"], + "highlight": { + "border": COLORS["router_border"], + "background": COLORS["router"], + }, + }, + "shape": "box", + "font": {"color": COLORS["text"]}, + "borderWidth": 3, + "borderWidthSelected": 4, + "shapeProperties": {"borderDashes": [5, 5]}, + "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, + }, +} diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 4dd761137..f60e5c1e3 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,12 +1,14 @@ # flow.py +# flow.py + import asyncio import inspect from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union from pydantic import BaseModel -from crewai.flow.flow_visualizer import visualize_flow +from crewai.flow.flow_visualizer import plot_flow T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]]) @@ -268,5 +270,5 @@ class Flow(Generic[T], metaclass=FlowMeta): traceback.print_exc() - def visualize(self, filename: str = "crewai_flow_graph"): - visualize_flow(self, filename) + def plot(self, filename: str = "crewai_flow_graph"): + plot_flow(self, filename) diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index 8b00d8822..822f192b0 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -1,63 +1,27 @@ # flow_visualizer.py -import base64 import os -import re from pyvis.network import Network -DARK_GRAY = "#333333" -CREWAI_ORANGE = "#FF5A50" -GRAY = "#666666" -WHITE = "#FFFFFF" +from crewai.flow.config import COLORS, NODE_STYLES +from crewai.flow.html_template_handler import HTMLTemplateHandler +from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items +from crewai.flow.utils import calculate_node_levels +from crewai.flow.visualization_utils import ( + add_edges, + add_nodes_to_network, + compute_positions, +) -class FlowVisualizer: +class FlowPlot: def __init__(self, flow): self.flow = flow - self.colors = { - "bg": WHITE, - "start": CREWAI_ORANGE, - "method": DARK_GRAY, - "router": DARK_GRAY, - "router_border": CREWAI_ORANGE, - "edge": GRAY, - "router_edge": CREWAI_ORANGE, - "text": WHITE, - } - self.node_styles = { - "start": { - "color": self.colors["start"], - "shape": "box", - "font": {"color": self.colors["text"]}, - "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, - }, - "method": { - "color": self.colors["method"], - "shape": "box", - "font": {"color": self.colors["text"]}, - "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, - }, - "router": { - "color": { - "background": self.colors["router"], - "border": self.colors["router_border"], - "highlight": { - "border": self.colors["router_border"], - "background": self.colors["router"], - }, - }, - "shape": "box", - "font": {"color": self.colors["text"]}, - "borderWidth": 3, - "borderWidthSelected": 4, - "shapeProperties": {"borderDashes": [5, 5]}, - "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, - }, - } + self.colors = COLORS + self.node_styles = NODE_STYLES - # TODO: DROP LIB FOLDER POST GENERATION - def visualize(self, filename): + def plot(self, filename): net = Network( directed=True, height="750px", @@ -67,172 +31,16 @@ class FlowVisualizer: ) # Calculate levels for nodes - node_levels = self._calculate_node_levels() - - # Assign positions to nodes based on levels - y_spacing = 150 - x_spacing = 150 - level_nodes = {} - - # Store node positions for edge calculations - node_positions = {} - - for method_name, level in node_levels.items(): - level_nodes.setdefault(level, []).append(method_name) + node_levels = calculate_node_levels(self.flow) # Compute positions - for level, nodes in level_nodes.items(): - x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally - for i, method_name in enumerate(nodes): - x = x_offset + i * x_spacing - y = level * y_spacing - node_positions[method_name] = (x, y) + node_positions = compute_positions(self.flow, node_levels) - method = self.flow._methods.get(method_name) - if hasattr(method, "__is_start_method__"): - node_style = self.node_styles["start"] - elif hasattr(method, "__is_router__"): - node_style = self.node_styles["router"] - else: - node_style = self.node_styles["method"] + # Add nodes to the network + add_nodes_to_network(net, self.flow, node_positions, self.node_styles) - net.add_node( - method_name, - label=method_name, - x=x, - y=y, - fixed=True, - physics=False, - **node_style, - ) - - ancestors = self._build_ancestor_dict() - parent_children = self._build_parent_children_dict() - - # Add edges - for method_name in self.flow._listeners: - condition_type, trigger_methods = self.flow._listeners[method_name] - is_and_condition = condition_type == "AND" - - for trigger in trigger_methods: - if ( - trigger in self.flow._methods - or trigger in self.flow._routers.values() - ): - is_router_edge = any( - trigger in paths for paths in self.flow._router_paths.values() - ) - edge_color = ( - self.colors["router_edge"] - if is_router_edge - else self.colors["edge"] - ) - - # Determine if this edge forms a cycle - is_cycle_edge = self._is_ancestor(trigger, method_name, ancestors) - - # Determine if parent has multiple children - parent_has_multiple_children = ( - len(parent_children.get(trigger, [])) > 1 - ) - - # Edge curvature logic - needs_curvature = is_cycle_edge or parent_has_multiple_children - - if needs_curvature: - # Get node positions - source_pos = node_positions.get(trigger) - target_pos = node_positions.get(method_name) - - if source_pos and target_pos: - dx = target_pos[0] - source_pos[0] - - if dx <= 0: - # Child is left or directly below - smooth_type = "curvedCCW" # Curve left and down - else: - # Child is to the right - smooth_type = "curvedCW" # Curve right and down - - index = self._get_child_index( - trigger, method_name, parent_children - ) - edge_smooth = { - "type": smooth_type, - "roundness": 0.2 + (0.1 * index), - } - else: - # Fallback curvature - edge_smooth = {"type": "cubicBezier"} - else: - edge_smooth = False # Draw straight line - - edge_style = { - "color": edge_color, - "width": 2, - "arrows": "to", - "dashes": True if is_router_edge or is_and_condition else False, - "smooth": edge_smooth, - } - - net.add_edge(trigger, method_name, **edge_style) - - # Add edges from router methods to their possible paths - for router_method_name, paths in self.flow._router_paths.items(): - for path in paths: - for listener_name, ( - condition_type, - trigger_methods, - ) in self.flow._listeners.items(): - if path in trigger_methods: - is_cycle_edge = self._is_ancestor( - trigger, method_name, ancestors - ) - - # Determine if parent has multiple children - parent_has_multiple_children = ( - len(parent_children.get(router_method_name, [])) > 1 - ) - - # Edge curvature logic - needs_curvature = is_cycle_edge or parent_has_multiple_children - - if needs_curvature: - # Get node positions - source_pos = node_positions.get(router_method_name) - target_pos = node_positions.get(listener_name) - - if source_pos and target_pos: - dx = target_pos[0] - source_pos[0] - - if dx <= 0: - # Child is left or directly below - smooth_type = "curvedCCW" # Curve left and down - else: - # Child is to the right - smooth_type = "curvedCW" # Curve right and down - - index = self._get_child_index( - router_method_name, listener_name, parent_children - ) - edge_smooth = { - "type": smooth_type, - "roundness": 0.2 + (0.1 * index), - } - else: - # Fallback curvature - edge_smooth = {"type": "cubicBezier"} - else: - edge_smooth = False # Straight line - - edge_style = { - "color": self.colors["router_edge"], - "width": 2, - "arrows": "to", - "dashes": True, - "smooth": edge_smooth, - } - net.add_edge(router_method_name, listener_name, **edge_style) + # Add edges to the network + add_edges(net, self.flow, node_positions, self.colors) # Set options to disable physics net.set_options( @@ -246,229 +54,46 @@ class FlowVisualizer: ) network_html = net.generate_html() - - # Extract just the body content from the generated HTML - match = re.search("(.*?)", network_html, re.DOTALL) - if match: - network_body = match.group(1) - else: - network_body = "" - - # Read the custom template - current_dir = os.path.dirname(__file__) - template_path = os.path.join( - current_dir, "assets", "crewai_flow_visual_template.html" - ) - with open(template_path, "r", encoding="utf-8") as f: - html_template = f.read() - - # Generate the legend items HTML - legend_items = [ - {"label": "Start Method", "color": self.colors["start"]}, - {"label": "Method", "color": self.colors["method"]}, - { - "label": "Router", - "color": self.colors["router"], - "border": self.colors["router_border"], - "dashed": True, - }, - {"label": "Trigger", "color": self.colors["edge"], "dashed": False}, - {"label": "AND Trigger", "color": self.colors["edge"], "dashed": True}, - { - "label": "Router Trigger", - "color": self.colors["router_edge"], - "dashed": True, - }, - ] - - legend_items_html = "" - for item in legend_items: - if "border" in item: - legend_items_html += f""" -
-
-
{item['label']}
-
- """ - elif item.get("dashed") is not None: - style = "dashed" if item["dashed"] else "solid" - legend_items_html += f""" -
-
-
{item['label']}
-
- """ - else: - legend_items_html += f""" -
-
-
{item['label']}
-
- """ - - # Read the logo file and encode it - logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg") - with open(logo_path, "rb") as logo_file: - logo_svg_data = logo_file.read() - logo_svg_base64 = base64.b64encode(logo_svg_data).decode("utf-8") - - # Replace placeholders in the template - final_html_content = html_template.replace("{{ title }}", "Flow Graph") - final_html_content = final_html_content.replace( - "{{ network_content }}", network_body - ) - final_html_content = final_html_content.replace( - "{{ logo_svg_base64 }}", logo_svg_base64 - ) - final_html_content = final_html_content.replace( - "", legend_items_html - ) + final_html_content = self._generate_final_html(network_html) # Save the final HTML content to the file with open(f"{filename}.html", "w", encoding="utf-8") as f: f.write(final_html_content) print(f"Graph saved as {filename}.html") - def _calculate_node_levels(self): - levels = {} - queue = [] - visited = set() - pending_and_listeners = {} + self._cleanup_pyvis_lib() - # Make all start methods at level 0 - for method_name, method in self.flow._methods.items(): - if hasattr(method, "__is_start_method__"): - levels[method_name] = 0 - queue.append(method_name) + def _generate_final_html(self, network_html): + # Extract just the body content from the generated HTML + current_dir = os.path.dirname(__file__) + template_path = os.path.join( + current_dir, "assets", "crewai_flow_visual_template.html" + ) + logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg") - # Breadth-first traversal to assign levels - while queue: - current = queue.pop(0) - current_level = levels[current] - visited.add(current) + html_handler = HTMLTemplateHandler(template_path, logo_path) + network_body = html_handler.extract_body_content(network_html) - for listener_name, ( - condition_type, - trigger_methods, - ) in self.flow._listeners.items(): - if condition_type == "OR": - if current in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) - elif condition_type == "AND": - if listener_name not in pending_and_listeners: - pending_and_listeners[listener_name] = set() - if current in trigger_methods: - pending_and_listeners[listener_name].add(current) - if set(trigger_methods) == pending_and_listeners[listener_name]: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) + # Generate the legend items HTML + legend_items = get_legend_items(self.colors) + legend_items_html = generate_legend_items_html(legend_items) + final_html_content = html_handler.generate_final_html( + network_body, legend_items_html + ) + return final_html_content - # Handle router connections - if current in self.flow._routers.values(): - router_method_name = current - paths = self.flow._router_paths.get(router_method_name, []) - for path in paths: - for listener_name, ( - condition_type, - trigger_methods, - ) in self.flow._listeners.items(): - if path in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) - return levels + def _cleanup_pyvis_lib(self): + # Clean up the generated lib folder + lib_folder = os.path.join(os.getcwd(), "lib") + try: + if os.path.exists(lib_folder) and os.path.isdir(lib_folder): + import shutil - def _count_outgoing_edges(self): - counts = {} - for method_name in self.flow._methods: - counts[method_name] = 0 - for method_name in self.flow._listeners: - _, trigger_methods = self.flow._listeners[method_name] - for trigger in trigger_methods: - if trigger in self.flow._methods: - counts[trigger] += 1 - return counts - - def _build_ancestor_dict(self): - ancestors = {node: set() for node in self.flow._methods} - visited = set() - for node in self.flow._methods: - if node not in visited: - self._dfs_ancestors(node, ancestors, visited) - - return ancestors - - def _dfs_ancestors(self, node, ancestors, visited): - if node in visited: - return - visited.add(node) - - # Handle regular listeners - for listener_name, (_, trigger_methods) in self.flow._listeners.items(): - if node in trigger_methods: - ancestors[listener_name].add(node) - ancestors[listener_name].update(ancestors[node]) - self._dfs_ancestors(listener_name, ancestors, visited) - - # Handle router methods separately - if node in self.flow._routers.values(): - router_method_name = node - paths = self.flow._router_paths.get(router_method_name, []) - for path in paths: - for listener_name, (_, trigger_methods) in self.flow._listeners.items(): - if path in trigger_methods: - # Only propagate the ancestors of the router method, not the router method itself - ancestors[listener_name].update(ancestors[node]) - self._dfs_ancestors(listener_name, ancestors, visited) - - def _is_ancestor(self, node, ancestor_candidate, ancestors): - return ancestor_candidate in ancestors.get(node, set()) - - def _build_parent_children_dict(self): - parent_children = {} - - # Map listeners to their trigger methods - for listener_name, (_, trigger_methods) in self.flow._listeners.items(): - for trigger in trigger_methods: - if trigger not in parent_children: - parent_children[trigger] = [] - if listener_name not in parent_children[trigger]: - parent_children[trigger].append(listener_name) - - # Map router methods to their paths and to listeners - for router_method_name, paths in self.flow._router_paths.items(): - for path in paths: - # Map router method to listeners of each path - for listener_name, (_, trigger_methods) in self.flow._listeners.items(): - if path in trigger_methods: - if router_method_name not in parent_children: - parent_children[router_method_name] = [] - if listener_name not in parent_children[router_method_name]: - parent_children[router_method_name].append(listener_name) - - return parent_children - - def _get_child_index(self, parent, child, parent_children): - children = parent_children.get(parent, []) - children.sort() - return children.index(child) + shutil.rmtree(lib_folder) + except Exception as e: + print(f"Error cleaning up {lib_folder}: {e}") -def visualize_flow(flow, filename="flow_graph"): - visualizer = FlowVisualizer(flow) - visualizer.visualize(filename) +def plot_flow(flow, filename="flow_graph"): + visualizer = FlowPlot(flow) + visualizer.plot(filename) diff --git a/src/crewai/flow/html_template_handler.py b/src/crewai/flow/html_template_handler.py new file mode 100644 index 000000000..8a88da42a --- /dev/null +++ b/src/crewai/flow/html_template_handler.py @@ -0,0 +1,66 @@ +import base64 +import os +import re + + +class HTMLTemplateHandler: + def __init__(self, template_path, logo_path): + self.template_path = template_path + self.logo_path = logo_path + + def read_template(self): + with open(self.template_path, "r", encoding="utf-8") as f: + return f.read() + + def encode_logo(self): + with open(self.logo_path, "rb") as logo_file: + logo_svg_data = logo_file.read() + return base64.b64encode(logo_svg_data).decode("utf-8") + + def extract_body_content(self, html): + match = re.search("(.*?)", html, re.DOTALL) + return match.group(1) if match else "" + + def generate_legend_items_html(self, legend_items): + legend_items_html = "" + for item in legend_items: + if "border" in item: + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + elif item.get("dashed") is not None: + style = "dashed" if item["dashed"] else "solid" + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + else: + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + return legend_items_html + + def generate_final_html(self, network_body, legend_items_html, title="Flow Graph"): + html_template = self.read_template() + logo_svg_base64 = self.encode_logo() + + final_html_content = html_template.replace("{{ title }}", title) + final_html_content = final_html_content.replace( + "{{ network_content }}", network_body + ) + final_html_content = final_html_content.replace( + "{{ logo_svg_base64 }}", logo_svg_base64 + ) + final_html_content = final_html_content.replace( + "", legend_items_html + ) + + return final_html_content diff --git a/src/crewai/flow/legend_generator.py b/src/crewai/flow/legend_generator.py new file mode 100644 index 000000000..83d9b97a2 --- /dev/null +++ b/src/crewai/flow/legend_generator.py @@ -0,0 +1,46 @@ +def get_legend_items(colors): + return [ + {"label": "Start Method", "color": colors["start"]}, + {"label": "Method", "color": colors["method"]}, + { + "label": "Router", + "color": colors["router"], + "border": colors["router_border"], + "dashed": True, + }, + {"label": "Trigger", "color": colors["edge"], "dashed": False}, + {"label": "AND Trigger", "color": colors["edge"], "dashed": True}, + { + "label": "Router Trigger", + "color": colors["router_edge"], + "dashed": True, + }, + ] + + +def generate_legend_items_html(legend_items): + legend_items_html = "" + for item in legend_items: + if "border" in item: + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + elif item.get("dashed") is not None: + style = "dashed" if item["dashed"] else "solid" + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + else: + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + return legend_items_html diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py new file mode 100644 index 000000000..f2dbfb7fd --- /dev/null +++ b/src/crewai/flow/utils.py @@ -0,0 +1,143 @@ +def calculate_node_levels(flow): + levels = {} + queue = [] + visited = set() + pending_and_listeners = {} + + # Make all start methods at level 0 + for method_name, method in flow._methods.items(): + if hasattr(method, "__is_start_method__"): + levels[method_name] = 0 + queue.append(method_name) + + # Breadth-first traversal to assign levels + while queue: + current = queue.pop(0) + current_level = levels[current] + visited.add(current) + + for listener_name, ( + condition_type, + trigger_methods, + ) in flow._listeners.items(): + if condition_type == "OR": + if current in trigger_methods: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + elif condition_type == "AND": + if listener_name not in pending_and_listeners: + pending_and_listeners[listener_name] = set() + if current in trigger_methods: + pending_and_listeners[listener_name].add(current) + if set(trigger_methods) == pending_and_listeners[listener_name]: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + + # Handle router connections + if current in flow._routers.values(): + router_method_name = current + paths = flow._router_paths.get(router_method_name, []) + for path in paths: + for listener_name, ( + condition_type, + trigger_methods, + ) in flow._listeners.items(): + if path in trigger_methods: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + return levels + + +def count_outgoing_edges(flow): + counts = {} + for method_name in flow._methods: + counts[method_name] = 0 + for method_name in flow._listeners: + _, trigger_methods = flow._listeners[method_name] + for trigger in trigger_methods: + if trigger in flow._methods: + counts[trigger] += 1 + return counts + + +def build_ancestor_dict(flow): + ancestors = {node: set() for node in flow._methods} + visited = set() + for node in flow._methods: + if node not in visited: + dfs_ancestors(node, ancestors, visited, flow) + return ancestors + + +def dfs_ancestors(node, ancestors, visited, flow): + if node in visited: + return + visited.add(node) + + # Handle regular listeners + for listener_name, (_, trigger_methods) in flow._listeners.items(): + if node in trigger_methods: + ancestors[listener_name].add(node) + ancestors[listener_name].update(ancestors[node]) + dfs_ancestors(listener_name, ancestors, visited, flow) + + # Handle router methods separately + if node in flow._routers.values(): + router_method_name = node + paths = flow._router_paths.get(router_method_name, []) + for path in paths: + for listener_name, (_, trigger_methods) in flow._listeners.items(): + if path in trigger_methods: + # Only propagate the ancestors of the router method, not the router method itself + ancestors[listener_name].update(ancestors[node]) + dfs_ancestors(listener_name, ancestors, visited, flow) + + +def is_ancestor(node, ancestor_candidate, ancestors): + return ancestor_candidate in ancestors.get(node, set()) + + +def build_parent_children_dict(flow): + parent_children = {} + + # Map listeners to their trigger methods + for listener_name, (_, trigger_methods) in flow._listeners.items(): + for trigger in trigger_methods: + if trigger not in parent_children: + parent_children[trigger] = [] + if listener_name not in parent_children[trigger]: + parent_children[trigger].append(listener_name) + + # Map router methods to their paths and to listeners + for router_method_name, paths in flow._router_paths.items(): + for path in paths: + # Map router method to listeners of each path + for listener_name, (_, trigger_methods) in flow._listeners.items(): + if path in trigger_methods: + if router_method_name not in parent_children: + parent_children[router_method_name] = [] + if listener_name not in parent_children[router_method_name]: + parent_children[router_method_name].append(listener_name) + + return parent_children + + +def get_child_index(parent, child, parent_children): + children = parent_children.get(parent, []) + children.sort() + return children.index(child) diff --git a/src/crewai/flow/visualization_utils.py b/src/crewai/flow/visualization_utils.py new file mode 100644 index 000000000..ba2ba5f18 --- /dev/null +++ b/src/crewai/flow/visualization_utils.py @@ -0,0 +1,132 @@ +from .utils import ( + build_ancestor_dict, + build_parent_children_dict, + get_child_index, + is_ancestor, +) + + +def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150): + level_nodes = {} + node_positions = {} + + for method_name, level in node_levels.items(): + level_nodes.setdefault(level, []).append(method_name) + + for level, nodes in level_nodes.items(): + x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally + for i, method_name in enumerate(nodes): + x = x_offset + i * x_spacing + y = level * y_spacing + node_positions[method_name] = (x, y) + + return node_positions + + +def add_edges(net, flow, node_positions, colors): + ancestors = build_ancestor_dict(flow) + parent_children = build_parent_children_dict(flow) + + for method_name in flow._listeners: + condition_type, trigger_methods = flow._listeners[method_name] + is_and_condition = condition_type == "AND" + + for trigger in trigger_methods: + if trigger in flow._methods or trigger in flow._routers.values(): + is_router_edge = any( + trigger in paths for paths in flow._router_paths.values() + ) + edge_color = colors["router_edge"] if is_router_edge else colors["edge"] + + is_cycle_edge = is_ancestor(trigger, method_name, ancestors) + parent_has_multiple_children = len(parent_children.get(trigger, [])) > 1 + needs_curvature = is_cycle_edge or parent_has_multiple_children + + if needs_curvature: + source_pos = node_positions.get(trigger) + target_pos = node_positions.get(method_name) + + if source_pos and target_pos: + dx = target_pos[0] - source_pos[0] + smooth_type = "curvedCCW" if dx <= 0 else "curvedCW" + index = get_child_index(trigger, method_name, parent_children) + edge_smooth = { + "type": smooth_type, + "roundness": 0.2 + (0.1 * index), + } + else: + edge_smooth = {"type": "cubicBezier"} + else: + edge_smooth = False + + edge_style = { + "color": edge_color, + "width": 2, + "arrows": "to", + "dashes": True if is_router_edge or is_and_condition else False, + "smooth": edge_smooth, + } + + net.add_edge(trigger, method_name, **edge_style) + + for router_method_name, paths in flow._router_paths.items(): + for path in paths: + for listener_name, ( + condition_type, + trigger_methods, + ) in flow._listeners.items(): + if path in trigger_methods: + is_cycle_edge = is_ancestor(trigger, method_name, ancestors) + parent_has_multiple_children = ( + len(parent_children.get(router_method_name, [])) > 1 + ) + needs_curvature = is_cycle_edge or parent_has_multiple_children + + if needs_curvature: + source_pos = node_positions.get(router_method_name) + target_pos = node_positions.get(listener_name) + + if source_pos and target_pos: + dx = target_pos[0] - source_pos[0] + smooth_type = "curvedCCW" if dx <= 0 else "curvedCW" + index = get_child_index( + router_method_name, listener_name, parent_children + ) + edge_smooth = { + "type": smooth_type, + "roundness": 0.2 + (0.1 * index), + } + else: + edge_smooth = {"type": "cubicBezier"} + else: + edge_smooth = False + + edge_style = { + "color": colors["router_edge"], + "width": 2, + "arrows": "to", + "dashes": True, + "smooth": edge_smooth, + } + net.add_edge(router_method_name, listener_name, **edge_style) + + +def add_nodes_to_network(net, flow, node_positions, node_styles): + for method_name, (x, y) in node_positions.items(): + method = flow._methods.get(method_name) + if hasattr(method, "__is_start_method__"): + node_style = node_styles["start"] + elif hasattr(method, "__is_router__"): + node_style = node_styles["router"] + else: + node_style = node_styles["method"] + + net.add_node( + method_name, + label=method_name, + x=x, + y=y, + fixed=True, + physics=False, + **node_style, + )