mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Restructure
This commit is contained in:
@@ -21,7 +21,6 @@ def create_pipeline(name, router=False):
|
|||||||
(project_root / "src" / folder_name).mkdir(parents=True)
|
(project_root / "src" / folder_name).mkdir(parents=True)
|
||||||
(project_root / "src" / folder_name / "crews").mkdir(parents=True)
|
(project_root / "src" / folder_name / "crews").mkdir(parents=True)
|
||||||
(project_root / "src" / folder_name / "tools").mkdir(parents=True)
|
(project_root / "src" / folder_name / "tools").mkdir(parents=True)
|
||||||
(project_root / "src" / folder_name / "config").mkdir(parents=True)
|
|
||||||
(project_root / "tests").mkdir(exist_ok=True)
|
(project_root / "tests").mkdir(exist_ok=True)
|
||||||
|
|
||||||
# Create .env file
|
# Create .env file
|
||||||
@@ -35,12 +34,8 @@ def create_pipeline(name, router=False):
|
|||||||
# List of template files to copy
|
# List of template files to copy
|
||||||
root_template_files = [".gitignore", "pyproject.toml", "README.md"]
|
root_template_files = [".gitignore", "pyproject.toml", "README.md"]
|
||||||
src_template_files = ["__init__.py", "main.py", "pipeline.py"]
|
src_template_files = ["__init__.py", "main.py", "pipeline.py"]
|
||||||
tools_template_files = ["tools/custom_tool.py", "tools/__init__.py"]
|
tools_template_files = ["tools/__init__.py", "tools/custom_tool.py"]
|
||||||
config_template_files = ["config/agents.yaml", "config/tasks.yaml"]
|
crew_folders = ["research_crew", "write_x_crew", "write_linkedin_crew"]
|
||||||
crew_template_files = ["crews/research_crew.py", "crews/write_x_crew.py"]
|
|
||||||
|
|
||||||
if router:
|
|
||||||
crew_template_files.append("crews/write_linkedin_crew.py")
|
|
||||||
|
|
||||||
def process_file(src_file, dst_file):
|
def process_file(src_file, dst_file):
|
||||||
with open(src_file, "r") as file:
|
with open(src_file, "r") as file:
|
||||||
@@ -66,16 +61,22 @@ def create_pipeline(name, router=False):
|
|||||||
dst_file = project_root / "src" / folder_name / file_name
|
dst_file = project_root / "src" / folder_name / file_name
|
||||||
process_file(src_file, dst_file)
|
process_file(src_file, dst_file)
|
||||||
|
|
||||||
# Copy tools and config files
|
# Copy tools files
|
||||||
for file_name in tools_template_files + config_template_files:
|
for file_name in tools_template_files:
|
||||||
src_file = templates_dir / file_name
|
src_file = templates_dir / file_name
|
||||||
dst_file = project_root / "src" / folder_name / file_name
|
dst_file = project_root / "src" / folder_name / file_name
|
||||||
shutil.copy(src_file, dst_file)
|
shutil.copy(src_file, dst_file)
|
||||||
|
|
||||||
# Copy and process crew files
|
# Copy crew folders
|
||||||
for file_name in crew_template_files:
|
for crew_folder in crew_folders:
|
||||||
src_file = templates_dir / file_name
|
src_crew_folder = templates_dir / "crews" / crew_folder
|
||||||
dst_file = project_root / "src" / folder_name / file_name
|
dst_crew_folder = project_root / "src" / folder_name / "crews" / crew_folder
|
||||||
process_file(src_file, dst_file)
|
if src_crew_folder.exists():
|
||||||
|
shutil.copytree(src_crew_folder, dst_crew_folder)
|
||||||
|
else:
|
||||||
|
click.secho(
|
||||||
|
f"Warning: Crew folder {crew_folder} not found in template.",
|
||||||
|
fg="yellow",
|
||||||
|
)
|
||||||
|
|
||||||
click.secho(f"Pipeline {name} created successfully!", fg="green", bold=True)
|
click.secho(f"Pipeline {name} created successfully!", fg="green", bold=True)
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
researcher:
|
||||||
|
role: >
|
||||||
|
{topic} Senior Data Researcher
|
||||||
|
goal: >
|
||||||
|
Uncover cutting-edge developments in {topic}
|
||||||
|
backstory: >
|
||||||
|
You're a seasoned researcher with a knack for uncovering the latest
|
||||||
|
developments in {topic}. Known for your ability to find the most relevant
|
||||||
|
information and present it in a clear and concise manner.
|
||||||
|
|
||||||
|
reporting_analyst:
|
||||||
|
role: >
|
||||||
|
{topic} Reporting Analyst
|
||||||
|
goal: >
|
||||||
|
Create detailed reports based on {topic} data analysis and research findings
|
||||||
|
backstory: >
|
||||||
|
You're a meticulous analyst with a keen eye for detail. You're known for
|
||||||
|
your ability to turn complex data into clear and concise reports, making
|
||||||
|
it easy for others to understand and act on the information you provide.
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
research_task:
|
||||||
|
description: >
|
||||||
|
Conduct a thorough research about {topic}
|
||||||
|
Make sure you find any interesting and relevant information given
|
||||||
|
the current year is 2024.
|
||||||
|
expected_output: >
|
||||||
|
A list with 10 bullet points of the most relevant information about {topic}
|
||||||
|
agent: researcher
|
||||||
|
|
||||||
|
reporting_task:
|
||||||
|
description: >
|
||||||
|
Review the context you got and expand each topic into a full section for a report.
|
||||||
|
Make sure the report is detailed and contains any and all relevant information.
|
||||||
|
expected_output: >
|
||||||
|
A fully fledge reports with a title, mains topics, each with a full section of information.
|
||||||
|
agent: reporting_analyst
|
||||||
@@ -3,7 +3,7 @@ from crewai import Agent, Crew, Process, Task
|
|||||||
from crewai.project import CrewBase, agent, crew, task
|
from crewai.project import CrewBase, agent, crew, task
|
||||||
|
|
||||||
# Uncomment the following line to use an example of a custom tool
|
# Uncomment the following line to use an example of a custom tool
|
||||||
# from {{folder_name}}.tools.custom_tool import MyCustomTool
|
# from demo_pipeline.tools.custom_tool import MyCustomTool
|
||||||
|
|
||||||
# Check our tools documentations for more information on how to use them
|
# Check our tools documentations for more information on how to use them
|
||||||
# from crewai_tools import SerperDevTool
|
# from crewai_tools import SerperDevTool
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
from crewai import Agent, Crew, Process, Task
|
|
||||||
from crewai.project import CrewBase, agent, crew, task
|
|
||||||
|
|
||||||
# Uncomment the following line to use an example of a custom tool
|
|
||||||
# from {{folder_name}}.tools.custom_tool import MyCustomTool
|
|
||||||
|
|
||||||
# Check our tools documentations for more information on how to use them
|
|
||||||
# from crewai_tools import SerperDevTool
|
|
||||||
|
|
||||||
@CrewBase
|
|
||||||
class WriteXCrew():
|
|
||||||
"""Research Crew"""
|
|
||||||
agents_config = 'config/agents.yaml'
|
|
||||||
tasks_config = 'config/tasks.yaml'
|
|
||||||
|
|
||||||
@agent
|
|
||||||
def x_writer_agent(self) -> Agent:
|
|
||||||
return Agent(
|
|
||||||
config=self.agents_config['x_writer_agent'],
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
|
|
||||||
@task
|
|
||||||
def write_x_task(self) -> Task:
|
|
||||||
return Task(
|
|
||||||
config=self.tasks_config['write_x_task'],
|
|
||||||
)
|
|
||||||
|
|
||||||
@crew
|
|
||||||
def crew(self) -> Crew:
|
|
||||||
"""Creates the Write X Crew"""
|
|
||||||
return Crew(
|
|
||||||
agents=self.agents, # Automatically created by the @agent decorator
|
|
||||||
tasks=self.tasks, # Automatically created by the @task decorator
|
|
||||||
process=Process.sequential,
|
|
||||||
verbose=2,
|
|
||||||
)
|
|
||||||
@@ -1,23 +1,3 @@
|
|||||||
researcher:
|
|
||||||
role: >
|
|
||||||
{topic} Senior Data Researcher
|
|
||||||
goal: >
|
|
||||||
Uncover cutting-edge developments in {topic}
|
|
||||||
backstory: >
|
|
||||||
You're a seasoned researcher with a knack for uncovering the latest
|
|
||||||
developments in {topic}. Known for your ability to find the most relevant
|
|
||||||
information and present it in a clear and concise manner.
|
|
||||||
|
|
||||||
reporting_analyst:
|
|
||||||
role: >
|
|
||||||
{topic} Reporting Analyst
|
|
||||||
goal: >
|
|
||||||
Create detailed reports based on {topic} data analysis and research findings
|
|
||||||
backstory: >
|
|
||||||
You're a meticulous analyst with a keen eye for detail. You're known for
|
|
||||||
your ability to turn complex data into clear and concise reports, making
|
|
||||||
it easy for others to understand and act on the information you provide.
|
|
||||||
|
|
||||||
x_writer_agent:
|
x_writer_agent:
|
||||||
role: >
|
role: >
|
||||||
Expert Social Media Content Creator specializing in short form written content
|
Expert Social Media Content Creator specializing in short form written content
|
||||||
@@ -1,20 +1,3 @@
|
|||||||
research_task:
|
|
||||||
description: >
|
|
||||||
Conduct a thorough research about {topic}
|
|
||||||
Make sure you find any interesting and relevant information given
|
|
||||||
the current year is 2024.
|
|
||||||
expected_output: >
|
|
||||||
A list with 10 bullet points of the most relevant information about {topic}
|
|
||||||
agent: researcher
|
|
||||||
|
|
||||||
reporting_task:
|
|
||||||
description: >
|
|
||||||
Review the context you got and expand each topic into a full section for a report.
|
|
||||||
Make sure the report is detailed and contains any and all relevant information.
|
|
||||||
expected_output: >
|
|
||||||
A fully fledge reports with a title, mains topics, each with a full section of information.
|
|
||||||
agent: reporting_analyst
|
|
||||||
|
|
||||||
write_x_task:
|
write_x_task:
|
||||||
description: >
|
description: >
|
||||||
Using the research report provided, create an engaging short form post about {topic}.
|
Using the research report provided, create an engaging short form post about {topic}.
|
||||||
@@ -31,7 +14,7 @@ write_x_task:
|
|||||||
|
|
||||||
Title: {title}
|
Title: {title}
|
||||||
Research:
|
Research:
|
||||||
{research}
|
{body}
|
||||||
|
|
||||||
expected_output: >
|
expected_output: >
|
||||||
A compelling X post under 280 characters that effectively summarizes the key findings
|
A compelling X post under 280 characters that effectively summarizes the key findings
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
from crewai import Agent, Crew, Process, Task
|
||||||
|
from crewai.project import CrewBase, agent, crew, task
|
||||||
|
|
||||||
|
# Uncomment the following line to use an example of a custom tool
|
||||||
|
# from demo_pipeline.tools.custom_tool import MyCustomTool
|
||||||
|
|
||||||
|
# Check our tools documentations for more information on how to use them
|
||||||
|
# from crewai_tools import SerperDevTool
|
||||||
|
|
||||||
|
|
||||||
|
@CrewBase
|
||||||
|
class WriteXCrew:
|
||||||
|
"""Research Crew"""
|
||||||
|
|
||||||
|
agents_config = "config/agents.yaml"
|
||||||
|
tasks_config = "config/tasks.yaml"
|
||||||
|
|
||||||
|
@agent
|
||||||
|
def x_writer_agent(self) -> Agent:
|
||||||
|
return Agent(config=self.agents_config["x_writer_agent"], verbose=True)
|
||||||
|
|
||||||
|
@task
|
||||||
|
def write_x_task(self) -> Task:
|
||||||
|
return Task(
|
||||||
|
config=self.tasks_config["write_x_task"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@crew
|
||||||
|
def crew(self) -> Crew:
|
||||||
|
"""Creates the Write X Crew"""
|
||||||
|
return Crew(
|
||||||
|
agents=self.agents, # Automatically created by the @agent decorator
|
||||||
|
tasks=self.tasks, # Automatically created by the @task decorator
|
||||||
|
process=Process.sequential,
|
||||||
|
verbose=2,
|
||||||
|
)
|
||||||
@@ -9,7 +9,7 @@ async def run():
|
|||||||
inputs = [
|
inputs = [
|
||||||
{"topic": "AI wearables"},
|
{"topic": "AI wearables"},
|
||||||
]
|
]
|
||||||
pipeline = {{pipeline_name}}Pipeline().pipeline()
|
pipeline = {{pipeline_name}}Pipeline()
|
||||||
results = await pipeline.kickoff(inputs)
|
results = await pipeline.kickoff(inputs)
|
||||||
|
|
||||||
# Process and print results
|
# Process and print results
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ Key features:
|
|||||||
- The ResearchCrew's final task uses output_json to store all research findings in a JSON object.
|
- The ResearchCrew's final task uses output_json to store all research findings in a JSON object.
|
||||||
- This JSON object is then passed to the WriteXCrew, where tasks can access the research findings.
|
- This JSON object is then passed to the WriteXCrew, where tasks can access the research findings.
|
||||||
|
|
||||||
Example 2: Three-Stage Pipeline with Parallel Execution
|
Example 2: Two-Stage Pipeline with Parallel Execution
|
||||||
-------------------------------------------------------
|
-------------------------------------------------------
|
||||||
This pipeline consists of three crews:
|
This pipeline consists of three crews:
|
||||||
1. ResearchCrew: Performs research on a given topic.
|
1. ResearchCrew: Performs research on a given topic.
|
||||||
@@ -28,28 +28,24 @@ Usage:
|
|||||||
|
|
||||||
# Common imports for both examples
|
# Common imports for both examples
|
||||||
from crewai import Pipeline
|
from crewai import Pipeline
|
||||||
from crewai.project.pipeline_base import PipelineBase
|
|
||||||
|
|
||||||
|
|
||||||
from crewai.project.annotations import pipeline
|
|
||||||
|
|
||||||
# Uncomment the crews you need for your chosen example
|
# Uncomment the crews you need for your chosen example
|
||||||
from .crews.research_crew import ResearchCrew
|
from .crews.research_crew.research_crew import ResearchCrew
|
||||||
from .crews.write_x_crew import WriteXCrew
|
from .crews.write_x_crew.write_x_crew import WriteXCrew
|
||||||
# from .crews.write_linkedin_crew import WriteLinkedInCrew # Uncomment for Example 2
|
# from .crews.write_linkedin_crew import WriteLinkedInCrew # Uncomment for Example 2
|
||||||
|
|
||||||
# EXAMPLE 1: Two-Stage Pipeline
|
# EXAMPLE 1: Two-Stage Pipeline
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Uncomment the following code block to use Example 1
|
# Uncomment the following code block to use Example 1
|
||||||
|
|
||||||
@PipelineBase
|
|
||||||
class {{pipeline_name}}Pipeline:
|
class {{pipeline_name}}Pipeline:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Initialize crews
|
# Initialize crews
|
||||||
self.research_crew = ResearchCrew().crew()
|
self.research_crew = ResearchCrew().crew()
|
||||||
self.write_x_crew = WriteXCrew().crew()
|
self.write_x_crew = WriteXCrew().crew()
|
||||||
|
|
||||||
@pipeline
|
|
||||||
def create_pipeline(self):
|
def create_pipeline(self):
|
||||||
return Pipeline(
|
return Pipeline(
|
||||||
stages=[
|
stages=[
|
||||||
@@ -58,7 +54,7 @@ class {{pipeline_name}}Pipeline:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self, inputs):
|
async def kickoff(self, inputs):
|
||||||
pipeline = self.create_pipeline()
|
pipeline = self.create_pipeline()
|
||||||
results = await pipeline.kickoff(inputs)
|
results = await pipeline.kickoff(inputs)
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ import click
|
|||||||
|
|
||||||
|
|
||||||
def copy_template(src, dst, name, class_name, folder_name):
|
def copy_template(src, dst, name, class_name, folder_name):
|
||||||
print(f"Copying {src} to {dst}")
|
|
||||||
print(f"Interpolating {name}, {class_name}, {folder_name}")
|
|
||||||
"""Copy a file from src to dst."""
|
"""Copy a file from src to dst."""
|
||||||
with open(src, "r") as file:
|
with open(src, "r") as file:
|
||||||
content = file.read()
|
content = file.read()
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ class Pipeline(BaseModel):
|
|||||||
"""
|
"""
|
||||||
initial_input = copy.deepcopy(kickoff_input)
|
initial_input = copy.deepcopy(kickoff_input)
|
||||||
current_input = copy.deepcopy(kickoff_input)
|
current_input = copy.deepcopy(kickoff_input)
|
||||||
stages = copy.deepcopy(self.stages)
|
stages = self._copy_stages()
|
||||||
pipeline_usage_metrics: Dict[str, UsageMetrics] = {}
|
pipeline_usage_metrics: Dict[str, UsageMetrics] = {}
|
||||||
all_stage_outputs: List[List[CrewOutput]] = []
|
all_stage_outputs: List[List[CrewOutput]] = []
|
||||||
traces: List[List[Union[str, Dict[str, Any]]]] = [[initial_input]]
|
traces: List[List[Union[str, Dict[str, Any]]]] = [[initial_input]]
|
||||||
@@ -151,6 +151,7 @@ class Pipeline(BaseModel):
|
|||||||
while stage_index < len(stages):
|
while stage_index < len(stages):
|
||||||
stage = stages[stage_index]
|
stage = stages[stage_index]
|
||||||
stage_input = copy.deepcopy(current_input)
|
stage_input = copy.deepcopy(current_input)
|
||||||
|
print("stage_input", stage_input)
|
||||||
|
|
||||||
if isinstance(stage, Router):
|
if isinstance(stage, Router):
|
||||||
next_pipeline, route_taken = stage.route(stage_input)
|
next_pipeline, route_taken = stage.route(stage_input)
|
||||||
@@ -164,6 +165,7 @@ class Pipeline(BaseModel):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
stage_outputs, stage_trace = await self._process_stage(stage, stage_input)
|
stage_outputs, stage_trace = await self._process_stage(stage, stage_input)
|
||||||
|
print("stage_outputs", stage_outputs)
|
||||||
|
|
||||||
self._update_metrics_and_input(
|
self._update_metrics_and_input(
|
||||||
pipeline_usage_metrics, current_input, stage, stage_outputs
|
pipeline_usage_metrics, current_input, stage, stage_outputs
|
||||||
@@ -210,6 +212,8 @@ class Pipeline(BaseModel):
|
|||||||
Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]: The output and trace of the crew.
|
Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]: The output and trace of the crew.
|
||||||
"""
|
"""
|
||||||
output = await crew.kickoff_async(inputs=current_input)
|
output = await crew.kickoff_async(inputs=current_input)
|
||||||
|
print("output from crew kickoff", output)
|
||||||
|
print("output from crew kickoff dict", output.to_dict())
|
||||||
return [output], [crew.name or str(crew.id)]
|
return [output], [crew.name or str(crew.id)]
|
||||||
|
|
||||||
async def _process_parallel_crews(
|
async def _process_parallel_crews(
|
||||||
@@ -367,6 +371,24 @@ class Pipeline(BaseModel):
|
|||||||
]
|
]
|
||||||
return [crew_outputs + [output] for output in all_stage_outputs[-1]]
|
return [crew_outputs + [output] for output in all_stage_outputs[-1]]
|
||||||
|
|
||||||
|
def _copy_stages(self):
|
||||||
|
"""Create a deep copy of the Pipeline's stages."""
|
||||||
|
new_stages = []
|
||||||
|
for stage in self.stages:
|
||||||
|
if isinstance(stage, list):
|
||||||
|
new_stages.append(
|
||||||
|
[
|
||||||
|
crew.copy() if hasattr(crew, "copy") else copy.deepcopy(crew)
|
||||||
|
for crew in stage
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif hasattr(stage, "copy"):
|
||||||
|
new_stages.append(stage.copy())
|
||||||
|
else:
|
||||||
|
new_stages.append(copy.deepcopy(stage))
|
||||||
|
|
||||||
|
return new_stages
|
||||||
|
|
||||||
def __rshift__(self, other: PipelineStage) -> "Pipeline":
|
def __rshift__(self, other: PipelineStage) -> "Pipeline":
|
||||||
"""
|
"""
|
||||||
Implements the >> operator to add another Stage (Crew or List[Crew]) to an existing Pipeline.
|
Implements the >> operator to add another Stage (Crew or List[Crew]) to an existing Pipeline.
|
||||||
|
|||||||
@@ -24,7 +24,9 @@ def CrewBase(cls):
|
|||||||
original_agents_config_path = getattr(
|
original_agents_config_path = getattr(
|
||||||
cls, "agents_config", "config/agents.yaml"
|
cls, "agents_config", "config/agents.yaml"
|
||||||
)
|
)
|
||||||
|
print("Original agents config path: ", original_agents_config_path)
|
||||||
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
|
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
|
||||||
|
print("Original tasks config path: ", original_tasks_config_path)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -39,9 +41,11 @@ def CrewBase(cls):
|
|||||||
self.agents_config = self.load_yaml(
|
self.agents_config = self.load_yaml(
|
||||||
os.path.join(self.base_directory, self.original_agents_config_path)
|
os.path.join(self.base_directory, self.original_agents_config_path)
|
||||||
)
|
)
|
||||||
|
print("Agents config: ", self.agents_config)
|
||||||
self.tasks_config = self.load_yaml(
|
self.tasks_config = self.load_yaml(
|
||||||
os.path.join(self.base_directory, self.original_tasks_config_path)
|
os.path.join(self.base_directory, self.original_tasks_config_path)
|
||||||
)
|
)
|
||||||
|
print("Task config: ", self.tasks_config)
|
||||||
self.map_all_agent_variables()
|
self.map_all_agent_variables()
|
||||||
self.map_all_task_variables()
|
self.map_all_task_variables()
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from crewai.pipeline.pipeline import Pipeline
|
|||||||
from crewai.routers.router import Router
|
from crewai.routers.router import Router
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Could potentially remove. Need to check with @joao and @gui if this is needed for CrewAI+
|
||||||
def PipelineBase(cls):
|
def PipelineBase(cls):
|
||||||
class WrappedClass(cls):
|
class WrappedClass(cls):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
@@ -49,9 +50,7 @@ def PipelineBase(cls):
|
|||||||
elif isinstance(stage, list) and all(
|
elif isinstance(stage, list) and all(
|
||||||
isinstance(item, Crew) for item in stage
|
isinstance(item, Crew) for item in stage
|
||||||
):
|
):
|
||||||
self.stages.append(
|
self.stages.append(stage)
|
||||||
[crew_functions[item.__name__]() for item in stage]
|
|
||||||
)
|
|
||||||
|
|
||||||
def build_pipeline(self) -> Pipeline:
|
def build_pipeline(self) -> Pipeline:
|
||||||
return Pipeline(stages=self.stages)
|
return Pipeline(stages=self.stages)
|
||||||
|
|||||||
@@ -1,17 +1,20 @@
|
|||||||
from dataclasses import dataclass
|
from copy import deepcopy
|
||||||
from typing import Any, Callable, Dict, Generic, Tuple, TypeVar
|
from typing import Any, Callable, Dict, Generic, Tuple, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
T = TypeVar("T", bound=Dict[str, Any])
|
T = TypeVar("T", bound=Dict[str, Any])
|
||||||
U = TypeVar("U")
|
U = TypeVar("U")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Route(Generic[T, U]):
|
class Route(Generic[T, U]):
|
||||||
condition: Callable[[T], bool]
|
condition: Callable[[T], bool]
|
||||||
pipeline: U
|
pipeline: U
|
||||||
|
|
||||||
|
def __init__(self, condition: Callable[[T], bool], pipeline: U):
|
||||||
|
self.condition = condition
|
||||||
|
self.pipeline = pipeline
|
||||||
|
|
||||||
|
|
||||||
class Router(BaseModel, Generic[T, U]):
|
class Router(BaseModel, Generic[T, U]):
|
||||||
routes: Dict[str, Route[T, U]] = Field(
|
routes: Dict[str, Route[T, U]] = Field(
|
||||||
@@ -19,9 +22,21 @@ class Router(BaseModel, Generic[T, U]):
|
|||||||
description="Dictionary of route names to (condition, pipeline) tuples",
|
description="Dictionary of route names to (condition, pipeline) tuples",
|
||||||
)
|
)
|
||||||
default: U = Field(..., description="Default pipeline if no conditions are met")
|
default: U = Field(..., description="Default pipeline if no conditions are met")
|
||||||
|
_route_types: Dict[str, type] = PrivateAttr(default_factory=dict)
|
||||||
|
|
||||||
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
def __init__(self, routes: Dict[str, Route[T, U]], default: U, **data):
|
def __init__(self, routes: Dict[str, Route[T, U]], default: U, **data):
|
||||||
super().__init__(routes=routes, default=default, **data)
|
super().__init__(routes=routes, default=default, **data)
|
||||||
|
self._check_copyable(default)
|
||||||
|
for name, route in routes.items():
|
||||||
|
self._check_copyable(route.pipeline)
|
||||||
|
self._route_types[name] = type(route.pipeline)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _check_copyable(obj):
|
||||||
|
if not hasattr(obj, "copy") or not callable(getattr(obj, "copy")):
|
||||||
|
raise ValueError(f"Object of type {type(obj)} must have a 'copy' method")
|
||||||
|
|
||||||
def add_route(
|
def add_route(
|
||||||
self,
|
self,
|
||||||
@@ -40,7 +55,9 @@ class Router(BaseModel, Generic[T, U]):
|
|||||||
Returns:
|
Returns:
|
||||||
The Router instance for method chaining
|
The Router instance for method chaining
|
||||||
"""
|
"""
|
||||||
|
self._check_copyable(pipeline)
|
||||||
self.routes[name] = Route(condition=condition, pipeline=pipeline)
|
self.routes[name] = Route(condition=condition, pipeline=pipeline)
|
||||||
|
self._route_types[name] = type(pipeline)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def route(self, input_data: T) -> Tuple[U, str]:
|
def route(self, input_data: T) -> Tuple[U, str]:
|
||||||
@@ -58,3 +75,16 @@ class Router(BaseModel, Generic[T, U]):
|
|||||||
return route.pipeline, name
|
return route.pipeline, name
|
||||||
|
|
||||||
return self.default, "default"
|
return self.default, "default"
|
||||||
|
|
||||||
|
def copy(self) -> "Router[T, U]":
|
||||||
|
"""Create a deep copy of the Router."""
|
||||||
|
new_routes = {
|
||||||
|
name: Route(
|
||||||
|
condition=deepcopy(route.condition),
|
||||||
|
pipeline=route.pipeline.copy(),
|
||||||
|
)
|
||||||
|
for name, route in self.routes.items()
|
||||||
|
}
|
||||||
|
new_default = self.default.copy()
|
||||||
|
|
||||||
|
return Router(routes=new_routes, default=new_default)
|
||||||
|
|||||||
Reference in New Issue
Block a user