mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
WIP. Figuring out disconnect issue.
@@ -1,6 +1,6 @@
from copy import deepcopy
import os
import uuid
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from langchain.agents.agent import RunnableAgent
@@ -55,7 +55,7 @@ class Agent(BaseModel):
    _logger: Logger = PrivateAttr()
    _rpm_controller: RPMController = PrivateAttr(default=None)
    _request_within_rpm_limit: Any = PrivateAttr(default=None)
    _token_process: TokenProcess = TokenProcess()
    _token_process: TokenProcess = PrivateAttr(default=TokenProcess())

    formatting_errors: int = 0
    model_config = ConfigDict(arbitrary_types_allowed=True)
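Editor's note on the change above: the old `_token_process: TokenProcess = TokenProcess()` is a plain class attribute, created once at class-definition time, so every Agent would share one TokenProcess; `PrivateAttr` is Pydantic's mechanism for giving each model instance its own private value. A minimal plain-Python sketch of the shared-state pitfall (illustrative names, not crewAI code):

class TokenCounter:
    def __init__(self):
        self.total_tokens = 0

class AgentLike:
    token_counter = TokenCounter()  # evaluated once; one object shared by all instances

a, b = AgentLike(), AgentLike()
a.token_counter.total_tokens += 10
print(b.token_counter.total_tokens)  # 10 -- b "sees" a's usage

# Per-instance state instead: create the counter per instance (or, in Pydantic,
# declare it with PrivateAttr so each model instance gets its own copy).
class IsolatedAgentLike:
    def __init__(self):
        self.token_counter = TokenCounter()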
@@ -98,8 +98,7 @@ class Agent(BaseModel):
    agent_executor: InstanceOf[CrewAgentExecutor] = Field(
        default=None, description="An instance of the CrewAgentExecutor class."
    )
    crew: Any = Field(
        default=None, description="Crew to which the agent belongs.")
    crew: Any = Field(default=None, description="Crew to which the agent belongs.")
    tools_handler: InstanceOf[ToolsHandler] = Field(
        default=None, description="An instance of the ToolsHandler class."
    )
@@ -110,8 +109,7 @@ class Agent(BaseModel):
        default=None,
        description="Callback to be executed after each step of the agent execution.",
    )
    i18n: I18N = Field(
        default=I18N(), description="Internationalization settings.")
    i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
    llm: Any = Field(
        default_factory=lambda: ChatOpenAI(
            model=os.environ.get("OPENAI_MODEL_NAME", "gpt-4o")
@@ -172,8 +170,8 @@ class Agent(BaseModel):
    def set_agent_executor(self) -> "Agent":
        """set agent executor is set."""
        if hasattr(self.llm, "model_name"):
            token_handler = TokenCalcHandler(
                self.llm.model_name, self._token_process)
            token_handler = TokenCalcHandler(self.llm.model_name, self._token_process)
            print("TOKENHANDLER UUID", token_handler.id)

            # Ensure self.llm.callbacks is a list
            if not isinstance(self.llm.callbacks, list):
@@ -183,6 +181,10 @@ class Agent(BaseModel):
            if not any(
                isinstance(handler, TokenCalcHandler) for handler in self.llm.callbacks
            ):
                print(
                    "IMPORTANT: TokenCalcHandler not found in callbacks. Adding",
                    token_handler.id,
                )
                self.llm.callbacks.append(token_handler)

        if not self.agent_executor:
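The `any(isinstance(...))` guard above makes handler registration idempotent, which matters if `set_agent_executor` runs more than once for the same LLM. The same pattern in isolation (a sketch with illustrative names):

def add_handler_once(callbacks: list, handler) -> None:
    # Append only if no handler of the same type is already registered.
    if not any(isinstance(existing, type(handler)) for existing in callbacks):
        callbacks.append(handler)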
@@ -236,8 +238,7 @@ class Agent(BaseModel):
        self.agent_executor.tools = parsed_tools
        self.agent_executor.task = task

        self.agent_executor.tools_description = render_text_description(
            parsed_tools)
        self.agent_executor.tools_description = render_text_description(parsed_tools)
        self.agent_executor.tools_names = self.__tools_names(parsed_tools)

        result = self.agent_executor.invoke(
@@ -335,8 +336,7 @@ class Agent(BaseModel):
        )

        bind = self.llm.bind(stop=stop_words)
        inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(
            agent=self)
        inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(agent=self)
        self.agent_executor = CrewAgentExecutor(
            agent=RunnableAgent(runnable=inner_agent), **executor_args
        )
@@ -371,7 +371,7 @@ class Agent(BaseModel):
            thoughts += action.log
            thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
        return thoughts


    def copy(self):
        """Create a deep copy of the Agent."""
        exclude = {
@@ -379,8 +379,8 @@ class Agent(BaseModel):
            "_logger",
            "_rpm_controller",
            "_request_within_rpm_limit",
            "_token_process",
            "agent_executor",
            "_token_process",
            "agent_executor",
            "tools",
            "tools_handler",
            "cache_handler",

@@ -279,28 +279,37 @@ class Crew(BaseModel):
                f"The process '{self.process}' is not implemented yet."
            )

        print("FINISHED EXECUTION OF CREW", self.id)
        print("GOING TO INVESTIGATE TOKEN USAGE")
        for agent in self.agents:
            print("ANALYZING AGENT", agent.id)
            print("AGENT _token_process id: ", agent._token_process.id)

            print("AGENT USAGE METRICS", agent._token_process.get_summary())

        # TODO: THIS IS A BUG. ONLY THE LAST AGENT'S TOKEN USAGE IS BEING RETURNED
        metrics = metrics + [
            agent._token_process.get_summary() for agent in self.agents
        ]
        print()
        self.usage_metrics = {
            key: sum([m[key] for m in metrics if m is not None]) for key in metrics[0]
        }

        return result

    def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List:
    def kickoff_for_each(
        self, inputs: List[Dict[str, Any]]
    ) -> List[Union[str, Dict[str, Any]]]:
        """Executes the Crew's workflow for each input in the list and aggregates results."""
        results = []

        for input_data in inputs:
            crew = self.copy()

            for task in crew.tasks:
                task.interpolate_inputs(input_data)
            for agent in crew.agents:
                agent.interpolate_inputs(input_data)

            output = crew.kickoff()
            output = crew.kickoff(inputs=input_data)
            # TODO: FIGURE OUT HOW TO MERGE THE USAGE METRICS
            # TODO: I would expect we would want to merge the usage metrics from each crew execution
            results.append(output)

        return results
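On the two merge TODOs above: since each iteration runs on a fresh `crew = self.copy()`, one option is to collect each copy's `usage_metrics` after its kickoff and sum them key-wise on the parent. A hedged sketch (the helper is hypothetical; the keys mirror the dict used later in `_run_sequential_process`):

from typing import Any, Dict, List, Optional

USAGE_KEYS = ("total_tokens", "prompt_tokens", "completion_tokens", "successful_requests")

def merge_usage_metrics(runs: List[Optional[Dict[str, Any]]]) -> Dict[str, int]:
    # Sum the usage_metrics dicts left behind by each cloned crew's kickoff().
    totals = {key: 0 for key in USAGE_KEYS}
    for metrics in runs:
        for key in USAGE_KEYS:
            totals[key] += (metrics or {}).get(key, 0)
    return totals

# e.g. inside kickoff_for_each: append crew.usage_metrics per iteration, then
# self.usage_metrics = merge_usage_metrics(per_run_metrics). The same helper
# would serve the async variant below.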
@@ -315,17 +324,15 @@ class Crew(BaseModel):
        async def run_crew(input_data):
            crew = self.copy()

            for task in crew.tasks:
                task.interpolate_inputs(input_data)
            for agent in crew.agents:
                agent.interpolate_inputs(input_data)

            return await crew.kickoff_async()
            return await crew.kickoff_async(inputs=input_data)

        tasks = [asyncio.create_task(run_crew(input_data)) for input_data in inputs]

        results = await asyncio.gather(*tasks)

        # TODO: FIGURE OUT HOW TO MERGE THE USAGE METRICS
        # TODO: I would expect we would want to merge the usage metrics from each crew execution

        return results

    def train(self, n_iterations: int) -> None:
@@ -335,6 +342,13 @@ class Crew(BaseModel):
    def _run_sequential_process(self) -> Union[str, Dict[str, Any]]:
        """Executes tasks sequentially and returns the final output."""
        task_output = ""
        total_token_usage = {
            "total_tokens": 0,
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "successful_requests": 0,
        }

        for task in self.tasks:
            if task.agent.allow_delegation:  # type: ignore # Item "None" of "Agent | None" has no attribute "allow_delegation"
                agents_for_delegation = [
@@ -365,11 +379,19 @@ class Crew(BaseModel):
            if self.output_log_file:
                self._file_handler.log(agent=role, task=task_output, status="completed")

            # Update token usage for the current task
            current_token_usage = task.agent._token_process.get_summary()
            for key in total_token_usage:
                total_token_usage[key] += current_token_usage.get(key, 0)

            print("Updated total_token_usage:", total_token_usage)

        self._finish_execution(task_output)
        token_usage = task.agent._token_process.get_summary()  # type: ignore # Item "None" of "Agent | None" has no attribute "_token_process"
        return self._format_output(task_output, token_usage)  # type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
        # TODO: TEST AND FIX
        return self._format_output(task_output, total_token_usage)

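A caveat on the accumulation above: if `get_summary()` returns cumulative totals per agent (which the per-agent prints in `kickoff` suggest), adding a full summary after every task double-counts any agent that runs more than one task. A hedged sketch of delta-based accumulation instead (names are hypothetical):

from typing import Dict

_last_snapshot: Dict[int, Dict[str, int]] = {}

def add_task_delta(total: Dict[str, int], agent) -> None:
    # Add only what this agent consumed since we last sampled it.
    current = agent._token_process.get_summary()
    previous = _last_snapshot.get(id(agent), {})
    for key in total:
        total[key] += current.get(key, 0) - previous.get(key, 0)
    _last_snapshot[id(agent)] = current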
    def _run_hierarchical_process(self) -> Union[str, Dict[str, Any]]:
        """Creates and assigns a manager agent to make sure the crew completes the tasks."""
@@ -415,9 +437,11 @@ class Crew(BaseModel):
        self._finish_execution(task_output)
        manager_token_usage = manager._token_process.get_summary()
        return self._format_output(  # type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
            task_output, manager_token_usage
        ), manager_token_usage
        # TODO: TEST AND FIX
        return (
            self._format_output(task_output, manager_token_usage),
            manager_token_usage,
        )

    def copy(self):
        """Create a deep copy of the Crew."""
@@ -432,12 +456,18 @@ class Crew(BaseModel):
            "_short_term_memory",
            "_long_term_memory",
            "_entity_memory",
            "_telemetry",
            "agents",
            "tasks",
        }

        print("CREW ID", self.id)
        print("CURRENT IDS FOR AGENTS", [agent.id for agent in self.agents])
        print("CURRENT IDS FOR TASKS", [task.id for task in self.tasks])

        # TODO: I think there is a disconnect. We need to pass new agents and tasks to the new crew
        cloned_agents = [agent.copy() for agent in self.agents]
        cloned_tasks = [task.copy() for task in self.tasks]
        cloned_tasks = [task.copy(cloned_agents) for task in self.tasks]

        copied_data = self.model_dump(exclude=exclude)
        copied_data = {k: v for k, v in copied_data.items() if v is not None}
@@ -447,6 +477,12 @@ class Crew(BaseModel):

        copied_crew = Crew(**copied_data, agents=cloned_agents, tasks=cloned_tasks)

        print("COPIED CREW ID", copied_crew.id)
        print("NEW IDS FOR AGENTS", [agent.id for agent in copied_crew.agents])
        print("NEW IDS FOR TASKS", [task.id for task in copied_crew.tasks])

        # TODO: EXPERIMENT, PRINT ID'S AND MAKE SURE I'M CALLING THE RIGHT AGENTS AND TASKS

        return copied_crew

    def _set_tasks_callbacks(self) -> None:
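On the "disconnect" TODO above: if `task.copy()` clones its own agent, each task ends up bound to a different Agent object (with a different `_token_process`) than the ones in `cloned_agents`, so the crew later reads metrics from agents that never executed anything. Passing `cloned_agents` into `task.copy(...)` re-binds tasks to the shared clones. An alternative, hedged sketch that re-binds by object identity instead of role (the `agent=` keyword is an assumption; identity lookup avoids collisions when two agents share a role):

# Hypothetical sketch: one clone per original agent, looked up by identity.
agent_clones = {id(agent): agent.copy() for agent in self.agents}
cloned_agents = list(agent_clones.values())
cloned_tasks = [
    task.copy(agent=agent_clones[id(task.agent)] if task.agent else None)
    for task in self.tasks
]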
@@ -475,6 +511,7 @@ class Crew(BaseModel):
        If full_output is True, then returned data type will be a dictionary else returned outputs are string
        """
        if self.full_output:
            print("SPITTING OUT FULL OUTPUT FOR CREW", self.id)
            return {  # type: ignore # Incompatible return value type (got "dict[str, Sequence[str | TaskOutput | None]]", expected "str")
                "final_output": output,
                "tasks_outputs": [task.output for task in self.tasks if task],

@@ -1,8 +1,8 @@
from copy import deepcopy
import os
import re
import threading
import uuid
from copy import deepcopy
from typing import Any, Dict, List, Optional, Type

from langchain_openai import ChatOpenAI
@@ -192,13 +192,19 @@ class Task(BaseModel):
        )
        return result

    def _execute(self, agent, task, context, tools):
    def _execute(self, agent: Agent, task, context, tools):
        result = agent.execute_task(
            task=task,
            context=context,
            tools=tools,
        )

        print("CALLING EXECUTE ON TASK WITH ID", task.id)
        print("THIS TASK IS CALLING AGENT", agent.id)
        print(
            "CALLING TOKEN PROCESS in task on AGENT", agent._token_process.get_summary()
        )

        exported_output = self._export_output(result)

        self.output = TaskOutput(
@@ -246,7 +252,7 @@ class Task(BaseModel):
        """Increment the delegations counter."""
        self.delegations += 1

    def copy(self):
    def copy(self, agents: Optional[List[Agent]] = None) -> "Task":
        """Create a deep copy of the Task."""
        exclude = {
            "id",
@@ -254,6 +260,7 @@ class Task(BaseModel):
            "context",
            "tools",
        }
        print("ORIGINAL TOOLS:", self.tools)

        copied_data = self.model_dump(exclude=exclude)
        copied_data = {k: v for k, v in copied_data.items() if v is not None}
@@ -261,8 +268,16 @@ class Task(BaseModel):
        cloned_context = (
            [task.copy() for task in self.context] if self.context else None
        )
        cloned_agent = self.agent.copy() if self.agent else None
        cloned_tools = deepcopy(self.tools) if self.tools else None

        # TODO: Make sure this clone approach is correct.
        def get_agent_by_role(role: str) -> Agent | None:
            return next((agent for agent in agents if agent.role == role), None)

        cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
        # cloned_agent = self.agent.copy() if self.agent else None
        cloned_tools = deepcopy(self.tools) if self.tools else []

        print("CLONED_TOOLS", cloned_tools)

        copied_task = Task(
            **copied_data,
@@ -270,6 +285,8 @@ class Task(BaseModel):
            agent=cloned_agent,
            tools=cloned_tools,
        )

        print("TASK TOOLS:", copied_task.tools)
        return copied_task

    def _export_output(self, result: str) -> Any:
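Two edge cases in `get_agent_by_role` above: `agents` defaults to None, so `copy()` called without arguments raises a TypeError when the helper iterates None, and duplicate roles resolve to whichever agent comes first. A hedged guard sketch (names are hypothetical, not the crewAI API):

def _resolve_agent(self, agents):
    # Fall back to cloning the original agent when no candidate list is given.
    if self.agent is None:
        return None
    if not agents:
        return self.agent.copy()
    return next((a for a in agents if a.role == self.agent.role), None)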
@@ -328,7 +345,9 @@ class Task(BaseModel):
        if self.output_file:
            content = (
                exported_result if not self.output_pydantic else exported_result.json()  # type: ignore # "str" has no attribute "json"
                exported_result
                if not self.output_pydantic
                else exported_result.model_dump_json()
            )
            self._save_file(content)

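The `.json()` to `.model_dump_json()` switch above tracks the Pydantic v2 rename of the serialization method. A minimal illustration:

from pydantic import BaseModel

class Output(BaseModel):
    answer: str

print(Output(answer="42").model_dump_json())  # {"answer":"42"}; Pydantic v1 spelled this .json()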
@@ -1,3 +1,4 @@
import uuid
from typing import Any, Dict, List

import tiktoken
@@ -6,6 +7,7 @@ from langchain.schema import LLMResult


class TokenProcess:
    id = uuid.uuid4()  # TODO: REMOVE THIS
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
@@ -32,6 +34,7 @@ class TokenProcess:


class TokenCalcHandler(BaseCallbackHandler):
    id = uuid.uuid4()  # TODO: REMOVE THIS
    model: str = ""
    token_cost_process: TokenProcess

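Note on the two `id = uuid.uuid4()` lines above: these are class attributes, evaluated once at import time, so every TokenProcess instance (and every TokenCalcHandler) reports the same id. The debug prints built on them cannot distinguish instances, which may be masking the very disconnect this commit is chasing. A minimal sketch of the difference:

import uuid

class ClassLevel:
    id = uuid.uuid4()  # one value, shared by every instance

class InstanceLevel:
    def __init__(self):
        self.id = uuid.uuid4()  # fresh value per instance

a, b = ClassLevel(), ClassLevel()
print(a.id == b.id)  # True
x, y = InstanceLevel(), InstanceLevel()
print(x.id == y.id)  # False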
tests/cassettes/test_crew_kickoff_for_each_full_ouput.yaml (new file, 400750)
File diff suppressed because it is too large
@@ -391,6 +391,83 @@ def test_crew_full_ouput():
            "total_tokens": 439,
        },
    }
    assert False


"""
Issues:
- Each output is not tracking usage metrics
"""


# @pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_for_each_full_ouput():
    # TODO: Add docstrings to all tests
    from unittest.mock import patch

    inputs = [
        {"topic": "dog"},
        # {"topic": "cat"},
        # {"topic": "apple"},
    ]

    expected_outputs = [
        "Dogs are loyal companions and popular pets.",
        "Cats are independent and low-maintenance pets.",
        "Apples are a rich source of dietary fiber and vitamin C.",
    ]

    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
    )

    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )

    crew = Crew(agents=[agent], tasks=[task], full_output=True)
    results = crew.kickoff_for_each(inputs=inputs)
    # with patch.object(Agent, "execute_task") as mock_execute_task:
    #     mock_execute_task.side_effect = expected_outputs

    assert len(results) == len(inputs)
    print("RESULTS:", results)

    assert False


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_async_kickoff_for_each_full_ouput():
    inputs = [
        {"topic": "dog"},
        {"topic": "cat"},
        {"topic": "apple"},
    ]

    expected_outputs = [
        "Dogs are loyal companions and popular pets.",
        "Cats are independent and low-maintenance pets.",
        "Apples are a rich source of dietary fiber and vitamin C.",
    ]

    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
    )

    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )

    crew = Crew(agents=[agent], tasks=[task], full_output=True)
    assert False


def test_agents_rpm_is_never_set_if_crew_max_RPM_is_not_set():
@@ -465,6 +542,266 @@ def test_async_task_execution():
        join.assert_called()


@pytest.mark.vcr(filter_headers=["authorization"])
def test_kickoff_for_each_single_input():
    """Tests if kickoff_for_each works with a single input."""
    from unittest.mock import patch

    inputs = [{"topic": "dog"}]
    expected_outputs = ["Dogs are loyal companions and popular pets."]

    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
    )

    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )

    with patch.object(Agent, "execute_task") as mock_execute_task:
        mock_execute_task.side_effect = expected_outputs
        crew = Crew(agents=[agent], tasks=[task])
        results = crew.kickoff_for_each(inputs=inputs)

    assert len(results) == 1
    assert results == expected_outputs

@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_kickoff_for_each_multiple_inputs():
|
||||
"""Tests if kickoff_for_each works with multiple inputs."""
|
||||
from unittest.mock import patch
|
||||
|
||||
inputs = [
|
||||
{"topic": "dog"},
|
||||
{"topic": "cat"},
|
||||
{"topic": "apple"},
|
||||
]
|
||||
expected_outputs = [
|
||||
"Dogs are loyal companions and popular pets.",
|
||||
"Cats are independent and low-maintenance pets.",
|
||||
"Apples are a rich source of dietary fiber and vitamin C.",
|
||||
]
|
||||
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
goal="Express hot takes on {topic}.",
|
||||
backstory="You have a lot of experience with {topic}.",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Give me an analysis around {topic}.",
|
||||
expected_output="1 bullet point about {topic} that's under 15 words.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
with patch.object(Agent, "execute_task") as mock_execute_task:
|
||||
mock_execute_task.side_effect = expected_outputs
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
results = crew.kickoff_for_each(inputs=inputs)
|
||||
|
||||
assert len(results) == len(inputs)
|
||||
for i, res in enumerate(results):
|
||||
assert res == expected_outputs[i]
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_kickoff_for_each_empty_input():
|
||||
"""Tests if kickoff_for_each handles an empty input list."""
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
goal="Express hot takes on {topic}.",
|
||||
backstory="You have a lot of experience with {topic}.",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Give me an analysis around {topic}.",
|
||||
expected_output="1 bullet point about {topic} that's under 15 words.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
results = crew.kickoff_for_each(inputs=[])
|
||||
assert results == []
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_kickoff_for_each_invalid_input():
|
||||
"""Tests if kickoff_for_each raises TypeError for invalid input types."""
|
||||
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
goal="Express hot takes on {topic}.",
|
||||
backstory="You have a lot of experience with {topic}.",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Give me an analysis around {topic}.",
|
||||
expected_output="1 bullet point about {topic} that's under 15 words.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Pass a string instead of a list
|
||||
crew.kickoff_for_each("invalid input")
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_kickoff_for_each_error_handling():
|
||||
"""Tests error handling in kickoff_for_each when kickoff raises an error."""
|
||||
from unittest.mock import patch
|
||||
|
||||
inputs = [
|
||||
{"topic": "dog"},
|
||||
{"topic": "cat"},
|
||||
{"topic": "apple"},
|
||||
]
|
||||
expected_outputs = [
|
||||
"Dogs are loyal companions and popular pets.",
|
||||
"Cats are independent and low-maintenance pets.",
|
||||
"Apples are a rich source of dietary fiber and vitamin C.",
|
||||
]
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
goal="Express hot takes on {topic}.",
|
||||
backstory="You have a lot of experience with {topic}.",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Give me an analysis around {topic}.",
|
||||
expected_output="1 bullet point about {topic} that's under 15 words.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
|
||||
with patch.object(Crew, "kickoff") as mock_kickoff:
|
||||
mock_kickoff.side_effect = expected_outputs[:2] + [
|
||||
Exception("Simulated kickoff error")
|
||||
]
|
||||
with pytest.raises(Exception, match="Simulated kickoff error"):
|
||||
crew.kickoff_for_each(inputs=inputs)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_kickoff_async_basic_functionality_and_output():
|
||||
"""Tests the basic functionality and output of kickoff_async."""
|
||||
from unittest.mock import patch
|
||||
|
||||
inputs = {"topic": "dog"}
|
||||
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
goal="Express hot takes on {topic}.",
|
||||
backstory="You have a lot of experience with {topic}.",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Give me an analysis around {topic}.",
|
||||
expected_output="1 bullet point about {topic} that's under 15 words.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Create the crew
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
)
|
||||
|
||||
expected_output = "This is a sample output from kickoff."
|
||||
with patch.object(Crew, "kickoff", return_value=expected_output) as mock_kickoff:
|
||||
result = await crew.kickoff_async(inputs)
|
||||
|
||||
assert isinstance(result, str), "Result should be a string"
|
||||
assert result == expected_output, "Result should match expected output"
|
||||
mock_kickoff.assert_called_once_with(inputs)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@pytest.mark.asyncio # Use pytest-asyncio for async tests
|
||||
async def test_async_kickoff_for_each_async_basic_functionality_and_output():
|
||||
"""Tests the basic functionality and output of akickoff_for_each_async."""
|
||||
from unittest.mock import patch
|
||||
|
||||
inputs = [
|
||||
{"topic": "dog"},
|
||||
{"topic": "cat"},
|
||||
{"topic": "apple"},
|
||||
]
|
||||
|
||||
# Define expected outputs for each input
|
||||
expected_outputs = [
|
||||
"Dogs are loyal companions and popular pets.",
|
||||
"Cats are independent and low-maintenance pets.",
|
||||
"Apples are a rich source of dietary fiber and vitamin C.",
|
||||
]
|
||||
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
goal="Express hot takes on {topic}.",
|
||||
backstory="You have a lot of experience with {topic}.",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Give me an analysis around {topic}.",
|
||||
expected_output="1 bullet point about {topic} that's under 15 words.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
Crew, "kickoff_async", side_effect=expected_outputs
|
||||
) as mock_kickoff_async:
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
|
||||
results = await crew.kickoff_for_each_async(inputs)
|
||||
|
||||
assert len(results) == len(inputs)
|
||||
assert results == expected_outputs
|
||||
for input_data in inputs:
|
||||
mock_kickoff_async.assert_any_call(inputs=input_data)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_kickoff_for_each_async_empty_input():
|
||||
"""Tests if akickoff_for_each_async handles an empty input list."""
|
||||
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
goal="Express hot takes on {topic}.",
|
||||
backstory="You have a lot of experience with {topic}.",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Give me an analysis around {topic}.",
|
||||
expected_output="1 bullet point about {topic} that's under 15 words.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Create the crew
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
)
|
||||
|
||||
# Call the function we are testing
|
||||
results = await crew.kickoff_for_each_async([])
|
||||
|
||||
# Assertion
|
||||
assert results == [], "Result should be an empty list when input is empty"
|
||||
|
||||
|
||||
# TODO: TEST KICKOFF FOR EACH WITH USAGE METRICS
# TODO: TEST ASYNC KICKOFF FOR EACH WITH USAGE METRICS

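A hedged sketch of the usage-metrics test the TODOs above call for; the merged-metrics contract (`crew.usage_metrics` populated on the parent after the loop) is an assumption to firm up once the merging strategy lands:

def test_kickoff_for_each_merges_usage_metrics():
    from unittest.mock import patch

    inputs = [{"topic": "dog"}, {"topic": "cat"}]

    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
    )
    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )

    with patch.object(Agent, "execute_task") as mock_execute_task:
        mock_execute_task.side_effect = ["output 1", "output 2"]
        crew = Crew(agents=[agent], tasks=[task])
        crew.kickoff_for_each(inputs=inputs)

    # Assumed contract: the parent crew exposes merged metrics after the loop.
    for key in ("total_tokens", "prompt_tokens", "completion_tokens", "successful_requests"):
        assert key in crew.usage_metrics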
def test_set_agents_step_callback():
    from unittest.mock import patch