mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-19 01:12:38 +00:00
Compare commits
2 Commits
devin/1745
...
devin/1745
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1cf09ac7ce | ||
|
|
a36e696a69 |
47
examples/task_decomposition_example.py
Normal file
47
examples/task_decomposition_example.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
Example of using task decomposition in CrewAI.
|
||||
|
||||
This example demonstrates how to use the task decomposition feature
|
||||
to break down complex tasks into simpler sub-tasks.
|
||||
|
||||
Feature introduced in CrewAI v1.x.x
|
||||
"""
|
||||
|
||||
from crewai import Agent, Task, Crew
|
||||
|
||||
researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="Research effectively",
|
||||
backstory="You're an expert researcher with skills in breaking down complex topics.",
|
||||
)
|
||||
|
||||
research_task = Task(
|
||||
description="Research the impact of AI on various industries",
|
||||
expected_output="A comprehensive report covering multiple industries",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
sub_tasks = research_task.decompose(
|
||||
descriptions=[
|
||||
"Research AI impact on healthcare industry",
|
||||
"Research AI impact on finance industry",
|
||||
"Research AI impact on education industry",
|
||||
],
|
||||
expected_outputs=[
|
||||
"A report on AI in healthcare",
|
||||
"A report on AI in finance",
|
||||
"A report on AI in education",
|
||||
],
|
||||
names=["Healthcare", "Finance", "Education"],
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher],
|
||||
tasks=[research_task],
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
print("Final result:", result)
|
||||
|
||||
for i, sub_task in enumerate(research_task.sub_tasks):
|
||||
print(f"Sub-task {i+1} result: {sub_task.output.raw if hasattr(sub_task, 'output') and sub_task.output else 'No output'}")
|
||||
@@ -28,7 +28,7 @@ def create_flow(name):
|
||||
(project_root / "tests").mkdir(exist_ok=True)
|
||||
|
||||
# Create .env file
|
||||
with open(project_root / ".env", "w", encoding="utf-8", newline="\n") as file:
|
||||
with open(project_root / ".env", "w") as file:
|
||||
file.write("OPENAI_API_KEY=YOUR_API_KEY")
|
||||
|
||||
package_dir = Path(__file__).parent
|
||||
@@ -58,7 +58,7 @@ def create_flow(name):
|
||||
content = content.replace("{{flow_name}}", class_name)
|
||||
content = content.replace("{{folder_name}}", folder_name)
|
||||
|
||||
with open(dst_file, "w", encoding="utf-8", newline="\n") as file:
|
||||
with open(dst_file, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
# Copy and process root template files
|
||||
|
||||
@@ -138,22 +138,17 @@ def load_provider_data(cache_file, cache_expiry):
|
||||
|
||||
def read_cache_file(cache_file):
|
||||
"""
|
||||
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON
|
||||
or if there's an encoding error.
|
||||
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
|
||||
|
||||
Args:
|
||||
- cache_file (Path): The path to the cache file.
|
||||
|
||||
Returns:
|
||||
- dict or None: The JSON content of the cache file or None if the JSON is invalid or there's an encoding error.
|
||||
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
|
||||
"""
|
||||
try:
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
with open(cache_file, "r") as f:
|
||||
return json.load(f)
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Error reading cache file: Unicode decode error - {e}", fg="red")
|
||||
click.secho("This may be due to file encoding issues. Try deleting the cache file and trying again.", fg="yellow")
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
@@ -172,16 +167,13 @@ def fetch_provider_data(cache_file):
|
||||
response = requests.get(JSON_URL, stream=True, timeout=60)
|
||||
response.raise_for_status()
|
||||
data = download_data(response)
|
||||
with open(cache_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
with open(cache_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
return data
|
||||
except requests.RequestException as e:
|
||||
click.secho(f"Error fetching provider data: {e}", fg="red")
|
||||
except json.JSONDecodeError:
|
||||
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Unicode decode error when processing provider data: {e}", fg="red")
|
||||
click.secho("This may be due to encoding issues with the downloaded data.", fg="yellow")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -18,24 +18,19 @@ console = Console()
|
||||
|
||||
def copy_template(src, dst, name, class_name, folder_name):
|
||||
"""Copy a file from src to dst."""
|
||||
try:
|
||||
with open(src, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
with open(src, "r") as file:
|
||||
content = file.read()
|
||||
|
||||
# Interpolate the content
|
||||
content = content.replace("{{name}}", name)
|
||||
content = content.replace("{{crew_name}}", class_name)
|
||||
content = content.replace("{{folder_name}}", folder_name)
|
||||
# Interpolate the content
|
||||
content = content.replace("{{name}}", name)
|
||||
content = content.replace("{{crew_name}}", class_name)
|
||||
content = content.replace("{{folder_name}}", folder_name)
|
||||
|
||||
# Write the interpolated content to the new file
|
||||
with open(dst, "w", encoding="utf-8", newline="\n") as file:
|
||||
file.write(content)
|
||||
# Write the interpolated content to the new file
|
||||
with open(dst, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
click.secho(f" - Created {dst}", fg="green")
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Error reading template file {src}: Unicode decode error - {e}", fg="red")
|
||||
click.secho("This may be due to file encoding issues. Please ensure all template files use UTF-8 encoding.", fg="yellow")
|
||||
raise
|
||||
click.secho(f" - Created {dst}", fg="green")
|
||||
|
||||
|
||||
def read_toml(file_path: str = "pyproject.toml"):
|
||||
@@ -83,7 +78,7 @@ def _get_project_attribute(
|
||||
attribute = None
|
||||
|
||||
try:
|
||||
with open(pyproject_path, "r", encoding="utf-8") as f:
|
||||
with open(pyproject_path, "r") as f:
|
||||
pyproject_content = parse_toml(f.read())
|
||||
|
||||
dependencies = (
|
||||
@@ -124,7 +119,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
"""Fetch the environment variables from a .env file and return them as a dictionary."""
|
||||
try:
|
||||
# Read the .env file
|
||||
with open(env_file_path, "r", encoding="utf-8") as f:
|
||||
with open(env_file_path, "r") as f:
|
||||
env_content = f.read()
|
||||
|
||||
# Parse the .env file content to a dictionary
|
||||
@@ -138,9 +133,6 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {env_file_path} not found.")
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Error reading .env file: Unicode decode error - {e}", fg="red")
|
||||
click.secho("This may be due to file encoding issues. Please ensure the .env file uses UTF-8 encoding.", fg="yellow")
|
||||
except Exception as e:
|
||||
print(f"Error reading the .env file: {e}")
|
||||
|
||||
@@ -166,15 +158,10 @@ def tree_find_and_replace(directory, find, replace):
|
||||
for filename in files:
|
||||
filepath = os.path.join(path, filename)
|
||||
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as file:
|
||||
contents = file.read()
|
||||
with open(filepath, "w", encoding="utf-8", newline="\n") as file:
|
||||
file.write(contents.replace(find, replace))
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Error processing file {filepath}: Unicode decode error - {e}", fg="red")
|
||||
click.secho("This may be due to file encoding issues. Skipping this file.", fg="yellow")
|
||||
continue
|
||||
with open(filepath, "r") as file:
|
||||
contents = file.read()
|
||||
with open(filepath, "w") as file:
|
||||
file.write(contents.replace(find, replace))
|
||||
|
||||
if find in filename:
|
||||
new_filename = filename.replace(find, replace)
|
||||
@@ -202,15 +189,11 @@ def load_env_vars(folder_path):
|
||||
env_file_path = folder_path / ".env"
|
||||
env_vars = {}
|
||||
if env_file_path.exists():
|
||||
try:
|
||||
with open(env_file_path, "r", encoding="utf-8") as file:
|
||||
for line in file:
|
||||
key, _, value = line.strip().partition("=")
|
||||
if key and value:
|
||||
env_vars[key] = value
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Error reading .env file: Unicode decode error - {e}", fg="red")
|
||||
click.secho("This may be due to file encoding issues. Please ensure the .env file uses UTF-8 encoding.", fg="yellow")
|
||||
with open(env_file_path, "r") as file:
|
||||
for line in file:
|
||||
key, _, value = line.strip().partition("=")
|
||||
if key and value:
|
||||
env_vars[key] = value
|
||||
return env_vars
|
||||
|
||||
|
||||
@@ -261,11 +244,6 @@ def write_env_file(folder_path, env_vars):
|
||||
- env_vars (dict): A dictionary of environment variables to write.
|
||||
"""
|
||||
env_file_path = folder_path / ".env"
|
||||
try:
|
||||
with open(env_file_path, "w", encoding="utf-8", newline="\n") as file:
|
||||
for key, value in env_vars.items():
|
||||
file.write(f"{key}={value}\n")
|
||||
except Exception as e:
|
||||
click.secho(f"Error writing .env file: {e}", fg="red")
|
||||
click.secho("This may be due to file system permissions or other issues.", fg="yellow")
|
||||
raise
|
||||
with open(env_file_path, "w") as file:
|
||||
for key, value in env_vars.items():
|
||||
file.write(f"{key}={value}\n")
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
ForwardRef,
|
||||
)
|
||||
|
||||
from opentelemetry.trace import Span
|
||||
@@ -137,6 +138,16 @@ class Task(BaseModel):
|
||||
default=0,
|
||||
description="Current number of retries"
|
||||
)
|
||||
parent_task: Optional['Task'] = Field(
|
||||
default=None,
|
||||
description="Parent task that this task was decomposed from.",
|
||||
exclude=True,
|
||||
)
|
||||
sub_tasks: List['Task'] = Field(
|
||||
default_factory=list,
|
||||
description="Sub-tasks that this task was decomposed into.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
@field_validator("guardrail")
|
||||
@classmethod
|
||||
@@ -246,13 +257,151 @@ class Task(BaseModel):
|
||||
)
|
||||
return self
|
||||
|
||||
def decompose(
|
||||
self,
|
||||
descriptions: List[str],
|
||||
expected_outputs: Optional[List[str]] = None,
|
||||
names: Optional[List[str]] = None
|
||||
) -> List['Task']:
|
||||
"""
|
||||
Decompose a complex task into simpler sub-tasks.
|
||||
|
||||
Args:
|
||||
descriptions: List of descriptions for each sub-task.
|
||||
expected_outputs: Optional list of expected outputs for each sub-task.
|
||||
names: Optional list of names for each sub-task.
|
||||
|
||||
Returns:
|
||||
List of created sub-tasks.
|
||||
|
||||
Raises:
|
||||
ValueError: If descriptions is empty, or if expected_outputs or names
|
||||
have different lengths than descriptions.
|
||||
|
||||
Side Effects:
|
||||
Modifies self.sub_tasks by adding newly created sub-tasks.
|
||||
"""
|
||||
if not descriptions:
|
||||
raise ValueError("At least one sub-task description is required.")
|
||||
|
||||
if expected_outputs and len(expected_outputs) != len(descriptions):
|
||||
raise ValueError(
|
||||
f"If provided, expected_outputs must have the same length as descriptions. "
|
||||
f"Got {len(expected_outputs)} expected outputs and {len(descriptions)} descriptions."
|
||||
)
|
||||
|
||||
if names and len(names) != len(descriptions):
|
||||
raise ValueError(
|
||||
f"If provided, names must have the same length as descriptions. "
|
||||
f"Got {len(names)} names and {len(descriptions)} descriptions."
|
||||
)
|
||||
|
||||
for i, description in enumerate(descriptions):
|
||||
sub_task = Task(
|
||||
description=description,
|
||||
expected_output=expected_outputs[i] if expected_outputs else self.expected_output,
|
||||
name=names[i] if names else None,
|
||||
agent=self.agent, # Inherit the agent from the parent task
|
||||
tools=self.tools, # Inherit the tools from the parent task
|
||||
context=[self], # Set the parent task as context for the sub-task
|
||||
parent_task=self, # Reference back to the parent task
|
||||
)
|
||||
self.sub_tasks.append(sub_task)
|
||||
|
||||
return self.sub_tasks
|
||||
|
||||
def combine_sub_task_results(self) -> str:
|
||||
"""
|
||||
Combine the results from all sub-tasks into a single result for this task.
|
||||
|
||||
This method uses the task's agent to intelligently combine the results from
|
||||
all sub-tasks. It requires an agent capable of coherent text summarization
|
||||
and is designed for stateless prompt execution.
|
||||
|
||||
Returns:
|
||||
The combined result as a string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the task has no sub-tasks or no agent assigned.
|
||||
|
||||
Side Effects:
|
||||
None. This method does not modify the task's state.
|
||||
"""
|
||||
if not self.sub_tasks:
|
||||
raise ValueError("Task has no sub-tasks to combine results from.")
|
||||
|
||||
if not self.agent:
|
||||
raise ValueError("Task has no agent to combine sub-task results.")
|
||||
|
||||
sub_task_results = "\n\n".join([
|
||||
f"Sub-task: {sub_task.description}\nResult: {sub_task.output.raw if sub_task.output else 'No result'}"
|
||||
for sub_task in self.sub_tasks
|
||||
])
|
||||
|
||||
combine_prompt = f"""
|
||||
You have completed the following sub-tasks for the main task: "{self.description}"
|
||||
|
||||
{sub_task_results}
|
||||
|
||||
Based on all these sub-tasks, please provide a consolidated final answer for the main task.
|
||||
Expected output format: {self.expected_output if self.expected_output else 'Not specified'}
|
||||
"""
|
||||
|
||||
result = self.agent.execute_task(
|
||||
task=self,
|
||||
context=combine_prompt,
|
||||
tools=self.tools or []
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def execute_sync(
|
||||
self,
|
||||
agent: Optional[BaseAgent] = None,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> TaskOutput:
|
||||
"""Execute the task synchronously."""
|
||||
"""
|
||||
Execute the task synchronously.
|
||||
|
||||
If the task has sub-tasks and no output yet, this method will:
|
||||
1. Execute all sub-tasks first
|
||||
2. Combine their results using the agent
|
||||
3. Set the combined result as this task's output
|
||||
|
||||
Args:
|
||||
agent: Optional agent to execute the task with.
|
||||
context: Optional context to pass to the task.
|
||||
tools: Optional tools to pass to the task.
|
||||
|
||||
Returns:
|
||||
TaskOutput: The result of the task execution.
|
||||
|
||||
Side Effects:
|
||||
Sets self.output with the execution result.
|
||||
"""
|
||||
if self.sub_tasks and not self.output:
|
||||
for sub_task in self.sub_tasks:
|
||||
sub_task.execute_sync(
|
||||
agent=sub_task.agent or agent,
|
||||
context=context,
|
||||
tools=sub_task.tools or tools or [],
|
||||
)
|
||||
|
||||
# Combine the results from sub-tasks
|
||||
result = self.combine_sub_task_results()
|
||||
|
||||
self.output = TaskOutput(
|
||||
description=self.description,
|
||||
name=self.name,
|
||||
expected_output=self.expected_output,
|
||||
raw=result,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
output_format=self.output_format,
|
||||
)
|
||||
|
||||
return self.output
|
||||
|
||||
return self._execute_core(agent, context, tools)
|
||||
|
||||
@property
|
||||
@@ -278,6 +427,55 @@ class Task(BaseModel):
|
||||
).start()
|
||||
return future
|
||||
|
||||
def execute_sub_tasks_async(
|
||||
self,
|
||||
agent: Optional[BaseAgent] = None,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> List[Future[TaskOutput]]:
|
||||
"""
|
||||
Execute all sub-tasks asynchronously.
|
||||
|
||||
This method starts the execution of all sub-tasks in parallel and returns
|
||||
futures that can be awaited. After all futures are complete, you should call
|
||||
combine_sub_task_results() to aggregate the results.
|
||||
|
||||
Example:
|
||||
```python
|
||||
futures = task.execute_sub_tasks_async()
|
||||
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
# Combine the results
|
||||
result = task.combine_sub_task_results()
|
||||
```
|
||||
|
||||
Args:
|
||||
agent: Optional agent to execute the sub-tasks with.
|
||||
context: Optional context to pass to the sub-tasks.
|
||||
tools: Optional tools to pass to the sub-tasks.
|
||||
|
||||
Returns:
|
||||
List of futures for the sub-task executions.
|
||||
|
||||
Raises:
|
||||
ValueError: If the task has no sub-tasks.
|
||||
"""
|
||||
if not self.sub_tasks:
|
||||
return []
|
||||
|
||||
futures = []
|
||||
for sub_task in self.sub_tasks:
|
||||
future = sub_task.execute_async(
|
||||
agent=sub_task.agent or agent,
|
||||
context=context,
|
||||
tools=sub_task.tools or tools or [],
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
return futures
|
||||
|
||||
def _execute_task_async(
|
||||
self,
|
||||
agent: Optional[BaseAgent],
|
||||
@@ -434,6 +632,8 @@ class Task(BaseModel):
|
||||
"agent",
|
||||
"context",
|
||||
"tools",
|
||||
"parent_task",
|
||||
"sub_tasks",
|
||||
}
|
||||
|
||||
copied_data = self.model_dump(exclude=exclude)
|
||||
@@ -457,6 +657,7 @@ class Task(BaseModel):
|
||||
agent=cloned_agent,
|
||||
tools=cloned_tools,
|
||||
)
|
||||
|
||||
|
||||
return copied_task
|
||||
|
||||
@@ -526,3 +727,6 @@ class Task(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
return f"Task(description={self.description}, expected_output={self.expected_output})"
|
||||
|
||||
|
||||
Task.model_rebuild()
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import click
|
||||
from click.testing import CliRunner
|
||||
|
||||
from crewai.cli.cli import create
|
||||
from crewai.cli.create_crew import create_crew
|
||||
|
||||
|
||||
class TestCreateCrew(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.runner = CliRunner()
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.test_dir = Path(self.temp_dir.name)
|
||||
|
||||
def tearDown(self):
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
@patch("crewai.cli.create_crew.get_provider_data")
|
||||
@patch("crewai.cli.create_crew.select_provider")
|
||||
@patch("crewai.cli.create_crew.select_model")
|
||||
@patch("crewai.cli.create_crew.write_env_file")
|
||||
@patch("crewai.cli.create_crew.load_env_vars")
|
||||
@patch("click.confirm")
|
||||
def test_create_crew_handles_unicode(self, mock_confirm, mock_load_env,
|
||||
mock_write_env, mock_select_model,
|
||||
mock_select_provider, mock_get_provider_data):
|
||||
"""Test that create_crew command handles Unicode properly."""
|
||||
mock_confirm.return_value = True
|
||||
mock_load_env.return_value = {}
|
||||
mock_get_provider_data.return_value = {"openai": ["gpt-4"]}
|
||||
mock_select_provider.return_value = "openai"
|
||||
mock_select_model.return_value = "gpt-4"
|
||||
|
||||
templates_dir = Path("src/crewai/cli/templates/crew")
|
||||
templates_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
template_content = """
|
||||
Hello {{name}}! Unicode test: 你好, こんにちは, Привет 🚀
|
||||
Class: {{crew_name}}
|
||||
Folder: {{folder_name}}
|
||||
"""
|
||||
|
||||
(templates_dir / "tools").mkdir(exist_ok=True)
|
||||
(templates_dir / "config").mkdir(exist_ok=True)
|
||||
|
||||
for file_name in [".gitignore", "pyproject.toml", "README.md", "__init__.py", "main.py", "crew.py"]:
|
||||
with open(templates_dir / file_name, "w", encoding="utf-8") as f:
|
||||
f.write(template_content)
|
||||
|
||||
(templates_dir / "knowledge").mkdir(exist_ok=True)
|
||||
with open(templates_dir / "knowledge" / "user_preference.txt", "w", encoding="utf-8") as f:
|
||||
f.write(template_content)
|
||||
|
||||
for file_path in ["tools/custom_tool.py", "tools/__init__.py", "config/agents.yaml", "config/tasks.yaml"]:
|
||||
(templates_dir / file_path).parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(templates_dir / file_path, "w", encoding="utf-8") as f:
|
||||
f.write(template_content)
|
||||
|
||||
with patch("crewai.cli.create_crew.Path") as mock_path:
|
||||
mock_path.return_value = self.test_dir
|
||||
mock_path.side_effect = lambda x: self.test_dir / x if isinstance(x, str) else x
|
||||
|
||||
create_crew("test_crew", skip_provider=True)
|
||||
|
||||
crew_dir = self.test_dir / "test_crew"
|
||||
for root, _, files in os.walk(crew_dir):
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
self.assertIn("你好", content, f"Unicode characters not preserved in {file_path}")
|
||||
self.assertIn("🚀", content, f"Emoji not preserved in {file_path}")
|
||||
@@ -1,89 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai.cli.provider import fetch_provider_data, read_cache_file
|
||||
from crewai.cli.utils import (
|
||||
copy_template,
|
||||
load_env_vars,
|
||||
tree_find_and_replace,
|
||||
write_env_file,
|
||||
)
|
||||
|
||||
|
||||
class TestEncoding(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.test_dir = Path(self.temp_dir.name)
|
||||
|
||||
self.unicode_content = "Hello Unicode: 你好, こんにちは, Привет, مرحبا, 안녕하세요 🚀"
|
||||
self.src_file = self.test_dir / "src_file.txt"
|
||||
self.dst_file = self.test_dir / "dst_file.txt"
|
||||
|
||||
with open(self.src_file, "w", encoding="utf-8") as f:
|
||||
f.write(self.unicode_content)
|
||||
|
||||
def tearDown(self):
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
def test_copy_template_handles_unicode(self):
|
||||
"""Test that copy_template handles Unicode characters properly in all environments."""
|
||||
copy_template(
|
||||
self.src_file,
|
||||
self.dst_file,
|
||||
"test_name",
|
||||
"TestClass",
|
||||
"test_folder"
|
||||
)
|
||||
|
||||
with open(self.dst_file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
self.assertIn("你好", content)
|
||||
self.assertIn("こんにちは", content)
|
||||
self.assertIn("🚀", content)
|
||||
|
||||
def test_env_vars_handle_unicode(self):
|
||||
"""Test that environment variable functions handle Unicode characters properly."""
|
||||
test_env_path = self.test_dir / ".env"
|
||||
test_env_vars = {
|
||||
"KEY1": "Value with Unicode: 你好",
|
||||
"KEY2": "More Unicode: こんにちは 🚀"
|
||||
}
|
||||
|
||||
write_env_file(self.test_dir, test_env_vars)
|
||||
|
||||
loaded_vars = load_env_vars(self.test_dir)
|
||||
|
||||
self.assertEqual(loaded_vars["KEY1"], "Value with Unicode: 你好")
|
||||
self.assertEqual(loaded_vars["KEY2"], "More Unicode: こんにちは 🚀")
|
||||
|
||||
def test_tree_find_and_replace_handles_unicode(self):
|
||||
"""Test that tree_find_and_replace handles Unicode characters properly."""
|
||||
test_file = self.test_dir / "replace_test.txt"
|
||||
with open(test_file, "w", encoding="utf-8") as f:
|
||||
f.write("Replace this: PLACEHOLDER with Unicode: 你好")
|
||||
|
||||
tree_find_and_replace(self.test_dir, "PLACEHOLDER", "🚀")
|
||||
|
||||
with open(test_file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
self.assertIn("Replace this: 🚀 with Unicode: 你好", content)
|
||||
|
||||
@patch("crewai.cli.provider.requests.get")
|
||||
def test_provider_functions_handle_unicode(self, mock_get):
|
||||
"""Test that provider data functions handle Unicode properly."""
|
||||
mock_response = unittest.mock.Mock()
|
||||
mock_response.iter_content.return_value = [self.unicode_content.encode("utf-8")]
|
||||
mock_response.headers.get.return_value = str(len(self.unicode_content))
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
cache_file = self.test_dir / "cache.json"
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
f.write('{"model": "Unicode test: 你好 🚀"}')
|
||||
|
||||
cache_data = read_cache_file(cache_file)
|
||||
self.assertEqual(cache_data["model"], "Unicode test: 你好 🚀")
|
||||
157
tests/test_task_decomposition.py
Normal file
157
tests/test_task_decomposition.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from crewai import Agent, Task
|
||||
|
||||
|
||||
def test_task_decomposition_structure():
|
||||
"""Test that task decomposition creates the proper parent-child relationship."""
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research effectively",
|
||||
backstory="You're an expert researcher",
|
||||
)
|
||||
|
||||
parent_task = Task(
|
||||
description="Research the impact of AI on various industries",
|
||||
expected_output="A comprehensive report",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
sub_task_descriptions = [
|
||||
"Research AI impact on healthcare",
|
||||
"Research AI impact on finance",
|
||||
"Research AI impact on education",
|
||||
]
|
||||
|
||||
sub_tasks = parent_task.decompose(
|
||||
descriptions=sub_task_descriptions,
|
||||
expected_outputs=["Healthcare report", "Finance report", "Education report"],
|
||||
names=["Healthcare", "Finance", "Education"],
|
||||
)
|
||||
|
||||
assert len(sub_tasks) == 3
|
||||
assert len(parent_task.sub_tasks) == 3
|
||||
|
||||
for sub_task in sub_tasks:
|
||||
assert sub_task.parent_task == parent_task
|
||||
assert parent_task in sub_task.context
|
||||
|
||||
|
||||
def test_task_execution_with_sub_tasks():
|
||||
"""Test that executing a task with sub-tasks executes the sub-tasks first."""
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research effectively",
|
||||
backstory="You're an expert researcher",
|
||||
)
|
||||
|
||||
parent_task = Task(
|
||||
description="Research the impact of AI on various industries",
|
||||
expected_output="A comprehensive report",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
sub_task_descriptions = [
|
||||
"Research AI impact on healthcare",
|
||||
"Research AI impact on finance",
|
||||
"Research AI impact on education",
|
||||
]
|
||||
|
||||
parent_task.decompose(
|
||||
descriptions=sub_task_descriptions,
|
||||
expected_outputs=["Healthcare report", "Finance report", "Education report"],
|
||||
)
|
||||
|
||||
with patch.object(Agent, 'execute_task', return_value="Mock result") as mock_execute_task:
|
||||
result = parent_task.execute_sync()
|
||||
|
||||
assert mock_execute_task.call_count >= 3
|
||||
|
||||
for sub_task in parent_task.sub_tasks:
|
||||
assert sub_task.output is not None
|
||||
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
|
||||
|
||||
def test_combine_sub_task_results():
|
||||
"""Test that combining sub-task results works correctly."""
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research effectively",
|
||||
backstory="You're an expert researcher",
|
||||
)
|
||||
|
||||
parent_task = Task(
|
||||
description="Research the impact of AI on various industries",
|
||||
expected_output="A comprehensive report",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
sub_tasks = parent_task.decompose([
|
||||
"Research AI impact on healthcare",
|
||||
"Research AI impact on finance",
|
||||
])
|
||||
|
||||
for sub_task in sub_tasks:
|
||||
sub_task.output = Mock()
|
||||
sub_task.output.raw = f"Result for {sub_task.description}"
|
||||
|
||||
with patch.object(Agent, 'execute_task', return_value="Combined result") as mock_execute_task:
|
||||
result = parent_task.combine_sub_task_results()
|
||||
|
||||
assert mock_execute_task.called
|
||||
assert result == "Combined result"
|
||||
|
||||
|
||||
def test_task_decomposition_validation():
|
||||
"""Test that task decomposition validates inputs correctly."""
|
||||
parent_task = Task(
|
||||
description="Research the impact of AI",
|
||||
expected_output="A report",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="At least one sub-task description is required"):
|
||||
parent_task.decompose([])
|
||||
|
||||
with pytest.raises(ValueError, match="expected_outputs must have the same length"):
|
||||
parent_task.decompose(
|
||||
["Task 1", "Task 2"],
|
||||
expected_outputs=["Output 1"]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="names must have the same length"):
|
||||
parent_task.decompose(
|
||||
["Task 1", "Task 2"],
|
||||
names=["Name 1"]
|
||||
)
|
||||
|
||||
|
||||
def test_execute_sub_tasks_async():
|
||||
"""Test that executing sub-tasks asynchronously works correctly."""
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research effectively",
|
||||
backstory="You're an expert researcher",
|
||||
)
|
||||
|
||||
parent_task = Task(
|
||||
description="Research the impact of AI on various industries",
|
||||
expected_output="A comprehensive report",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
sub_tasks = parent_task.decompose([
|
||||
"Research AI impact on healthcare",
|
||||
"Research AI impact on finance",
|
||||
])
|
||||
|
||||
with patch.object(Task, 'execute_async') as mock_execute_async:
|
||||
mock_future = Mock()
|
||||
mock_execute_async.return_value = mock_future
|
||||
|
||||
futures = parent_task.execute_sub_tasks_async()
|
||||
|
||||
assert mock_execute_async.call_count == 2
|
||||
assert len(futures) == 2
|
||||
109
tests/test_task_decomposition_edge_cases.py
Normal file
109
tests/test_task_decomposition_edge_cases.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from crewai import Agent, Task, TaskOutput
|
||||
|
||||
|
||||
def test_combine_sub_task_results_no_sub_tasks():
|
||||
"""Test that combining sub-task results raises an error when there are no sub-tasks."""
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research effectively",
|
||||
backstory="You're an expert researcher",
|
||||
)
|
||||
|
||||
parent_task = Task(
|
||||
description="Research the impact of AI",
|
||||
expected_output="A report",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Task has no sub-tasks to combine results from"):
|
||||
parent_task.combine_sub_task_results()
|
||||
|
||||
|
||||
def test_combine_sub_task_results_no_agent():
|
||||
"""Test that combining sub-task results raises an error when there is no agent."""
|
||||
parent_task = Task(
|
||||
description="Research the impact of AI",
|
||||
expected_output="A report",
|
||||
)
|
||||
|
||||
sub_task = Task(
|
||||
description="Research AI impact on healthcare",
|
||||
expected_output="Healthcare report",
|
||||
parent_task=parent_task,
|
||||
)
|
||||
parent_task.sub_tasks.append(sub_task)
|
||||
|
||||
with pytest.raises(ValueError, match="Task has no agent to combine sub-task results"):
|
||||
parent_task.combine_sub_task_results()
|
||||
|
||||
|
||||
def test_execute_sync_sets_output_after_combining():
|
||||
"""Test that execute_sync sets the output after combining sub-task results."""
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research effectively",
|
||||
backstory="You're an expert researcher",
|
||||
)
|
||||
|
||||
parent_task = Task(
|
||||
description="Research the impact of AI",
|
||||
expected_output="A report",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
sub_tasks = parent_task.decompose([
|
||||
"Research AI impact on healthcare",
|
||||
"Research AI impact on finance",
|
||||
])
|
||||
|
||||
with patch.object(Agent, 'execute_task', return_value="Combined result") as mock_execute_task:
|
||||
result = parent_task.execute_sync()
|
||||
|
||||
assert parent_task.output is not None
|
||||
assert parent_task.output.raw == "Combined result"
|
||||
assert result.raw == "Combined result"
|
||||
|
||||
assert mock_execute_task.call_count >= 3
|
||||
|
||||
|
||||
def test_deep_cloning_prevents_shared_state():
|
||||
"""Test that deep cloning prevents shared mutable state between tasks."""
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research effectively",
|
||||
backstory="You're an expert researcher",
|
||||
)
|
||||
|
||||
parent_task = Task(
|
||||
description="Research the impact of AI",
|
||||
expected_output="A report",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
copied_task = parent_task.copy()
|
||||
|
||||
copied_task.description = "Modified description"
|
||||
|
||||
assert parent_task.description == "Research the impact of AI"
|
||||
assert copied_task.description == "Modified description"
|
||||
|
||||
parent_task.decompose(["Sub-task 1", "Sub-task 2"])
|
||||
|
||||
assert len(parent_task.sub_tasks) == 2
|
||||
assert len(copied_task.sub_tasks) == 0
|
||||
|
||||
|
||||
def test_execute_sub_tasks_async_empty_sub_tasks():
|
||||
"""Test that execute_sub_tasks_async returns an empty list when there are no sub-tasks."""
|
||||
parent_task = Task(
|
||||
description="Research the impact of AI",
|
||||
expected_output="A report",
|
||||
)
|
||||
|
||||
futures = parent_task.execute_sub_tasks_async()
|
||||
|
||||
assert isinstance(futures, list)
|
||||
assert len(futures) == 0
|
||||
Reference in New Issue
Block a user