WIP. Figuring out disconnect issue.

This commit is contained in:
Brandon Hancock
2024-06-25 15:23:32 -07:00
parent c4d76cde8f
commit cc1c97e87d
6 changed files with 401187 additions and 41 deletions

View File

@@ -1,6 +1,6 @@
from copy import deepcopy
import os import os
import uuid import uuid
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from langchain.agents.agent import RunnableAgent from langchain.agents.agent import RunnableAgent
@@ -55,7 +55,7 @@ class Agent(BaseModel):
_logger: Logger = PrivateAttr() _logger: Logger = PrivateAttr()
_rpm_controller: RPMController = PrivateAttr(default=None) _rpm_controller: RPMController = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None) _request_within_rpm_limit: Any = PrivateAttr(default=None)
_token_process: TokenProcess = TokenProcess() _token_process: TokenProcess = PrivateAttr(default=TokenProcess())
formatting_errors: int = 0 formatting_errors: int = 0
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -98,8 +98,7 @@ class Agent(BaseModel):
agent_executor: InstanceOf[CrewAgentExecutor] = Field( agent_executor: InstanceOf[CrewAgentExecutor] = Field(
default=None, description="An instance of the CrewAgentExecutor class." default=None, description="An instance of the CrewAgentExecutor class."
) )
crew: Any = Field( crew: Any = Field(default=None, description="Crew to which the agent belongs.")
default=None, description="Crew to which the agent belongs.")
tools_handler: InstanceOf[ToolsHandler] = Field( tools_handler: InstanceOf[ToolsHandler] = Field(
default=None, description="An instance of the ToolsHandler class." default=None, description="An instance of the ToolsHandler class."
) )
@@ -110,8 +109,7 @@ class Agent(BaseModel):
default=None, default=None,
description="Callback to be executed after each step of the agent execution.", description="Callback to be executed after each step of the agent execution.",
) )
i18n: I18N = Field( i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
default=I18N(), description="Internationalization settings.")
llm: Any = Field( llm: Any = Field(
default_factory=lambda: ChatOpenAI( default_factory=lambda: ChatOpenAI(
model=os.environ.get("OPENAI_MODEL_NAME", "gpt-4o") model=os.environ.get("OPENAI_MODEL_NAME", "gpt-4o")
@@ -172,8 +170,8 @@ class Agent(BaseModel):
def set_agent_executor(self) -> "Agent": def set_agent_executor(self) -> "Agent":
"""set agent executor is set.""" """set agent executor is set."""
if hasattr(self.llm, "model_name"): if hasattr(self.llm, "model_name"):
token_handler = TokenCalcHandler( token_handler = TokenCalcHandler(self.llm.model_name, self._token_process)
self.llm.model_name, self._token_process) print("TOKENHANDLER UUID", token_handler.id)
# Ensure self.llm.callbacks is a list # Ensure self.llm.callbacks is a list
if not isinstance(self.llm.callbacks, list): if not isinstance(self.llm.callbacks, list):
@@ -183,6 +181,10 @@ class Agent(BaseModel):
if not any( if not any(
isinstance(handler, TokenCalcHandler) for handler in self.llm.callbacks isinstance(handler, TokenCalcHandler) for handler in self.llm.callbacks
): ):
print(
"IMPORTANT: TokenCalcHandler not found in callbacks. Adding",
token_handler.id,
)
self.llm.callbacks.append(token_handler) self.llm.callbacks.append(token_handler)
if not self.agent_executor: if not self.agent_executor:
@@ -236,8 +238,7 @@ class Agent(BaseModel):
self.agent_executor.tools = parsed_tools self.agent_executor.tools = parsed_tools
self.agent_executor.task = task self.agent_executor.task = task
self.agent_executor.tools_description = render_text_description( self.agent_executor.tools_description = render_text_description(parsed_tools)
parsed_tools)
self.agent_executor.tools_names = self.__tools_names(parsed_tools) self.agent_executor.tools_names = self.__tools_names(parsed_tools)
result = self.agent_executor.invoke( result = self.agent_executor.invoke(
@@ -335,8 +336,7 @@ class Agent(BaseModel):
) )
bind = self.llm.bind(stop=stop_words) bind = self.llm.bind(stop=stop_words)
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser( inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(agent=self)
agent=self)
self.agent_executor = CrewAgentExecutor( self.agent_executor = CrewAgentExecutor(
agent=RunnableAgent(runnable=inner_agent), **executor_args agent=RunnableAgent(runnable=inner_agent), **executor_args
) )
@@ -371,7 +371,7 @@ class Agent(BaseModel):
thoughts += action.log thoughts += action.log
thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}" thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
return thoughts return thoughts
def copy(self): def copy(self):
"""Create a deep copy of the Agent.""" """Create a deep copy of the Agent."""
exclude = { exclude = {
@@ -379,8 +379,8 @@ class Agent(BaseModel):
"_logger", "_logger",
"_rpm_controller", "_rpm_controller",
"_request_within_rpm_limit", "_request_within_rpm_limit",
"_token_process", "_token_process",
"agent_executor", "agent_executor",
"tools", "tools",
"tools_handler", "tools_handler",
"cache_handler", "cache_handler",

View File

@@ -279,28 +279,37 @@ class Crew(BaseModel):
f"The process '{self.process}' is not implemented yet." f"The process '{self.process}' is not implemented yet."
) )
print("FINISHED EXECUTION OF CREW", self.id)
print("GOING TO INVESTIGATE TOKEN USAGE")
for agent in self.agents:
print("ANALYZING AGENT", agent.id)
print("AGENT _token_process id: ", agent._token_process.id)
print("AGENT USAGE METRICS", agent._token_process.get_summary())
# TODO: THIS IS A BUG. ONLY THE LAST AGENT'S TOKEN USAGE IS BEING RETURNED
metrics = metrics + [ metrics = metrics + [
agent._token_process.get_summary() for agent in self.agents agent._token_process.get_summary() for agent in self.agents
] ]
print()
self.usage_metrics = { self.usage_metrics = {
key: sum([m[key] for m in metrics if m is not None]) for key in metrics[0] key: sum([m[key] for m in metrics if m is not None]) for key in metrics[0]
} }
return result return result
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List: def kickoff_for_each(
self, inputs: List[Dict[str, Any]]
) -> List[Union[str, Dict[str, Any]]]:
"""Executes the Crew's workflow for each input in the list and aggregates results.""" """Executes the Crew's workflow for each input in the list and aggregates results."""
results = [] results = []
for input_data in inputs: for input_data in inputs:
crew = self.copy() crew = self.copy()
for task in crew.tasks: output = crew.kickoff(inputs=input_data)
task.interpolate_inputs(input_data) # TODO: FIGURE OUT HOW TO MERGE THE USAGE METRICS
for agent in crew.agents: # TODO: I would expect we would want to merge the usage metrics from each crew execution
agent.interpolate_inputs(input_data)
output = crew.kickoff()
results.append(output) results.append(output)
return results return results
@@ -315,17 +324,15 @@ class Crew(BaseModel):
async def run_crew(input_data): async def run_crew(input_data):
crew = self.copy() crew = self.copy()
for task in crew.tasks: return await crew.kickoff_async(inputs=input_data)
task.interpolate_inputs(input_data)
for agent in crew.agents:
agent.interpolate_inputs(input_data)
return await crew.kickoff_async()
tasks = [asyncio.create_task(run_crew(input_data)) for input_data in inputs] tasks = [asyncio.create_task(run_crew(input_data)) for input_data in inputs]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
# TODO: FIGURE OUT HOW TO MERGE THE USAGE METRICS
# TODO: I would expect we would want to merge the usage metrics from each crew execution
return results return results
def train(self, n_iterations: int) -> None: def train(self, n_iterations: int) -> None:
@@ -335,6 +342,13 @@ class Crew(BaseModel):
def _run_sequential_process(self) -> Union[str, Dict[str, Any]]: def _run_sequential_process(self) -> Union[str, Dict[str, Any]]:
"""Executes tasks sequentially and returns the final output.""" """Executes tasks sequentially and returns the final output."""
task_output = "" task_output = ""
total_token_usage = {
"total_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"successful_requests": 0,
}
for task in self.tasks: for task in self.tasks:
if task.agent.allow_delegation: # type: ignore # Item "None" of "Agent | None" has no attribute "allow_delegation" if task.agent.allow_delegation: # type: ignore # Item "None" of "Agent | None" has no attribute "allow_delegation"
agents_for_delegation = [ agents_for_delegation = [
@@ -365,11 +379,19 @@ class Crew(BaseModel):
if self.output_log_file: if self.output_log_file:
self._file_handler.log(agent=role, task=task_output, status="completed") self._file_handler.log(agent=role, task=task_output, status="completed")
# Update token usage for the current task
current_token_usage = task.agent._token_process.get_summary()
for key in total_token_usage:
total_token_usage[key] += current_token_usage.get(key, 0)
print("Updated total_token_usage:", total_token_usage)
self._finish_execution(task_output) self._finish_execution(task_output)
# type: ignore # Item "None" of "Agent | None" has no attribute "_token_process" # type: ignore # Item "None" of "Agent | None" has no attribute "_token_process")
token_usage = task.agent._token_process.get_summary()
# type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str") # type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
return self._format_output(task_output, token_usage) # TODO: TEST AND FIX
return self._format_output(task_output, total_token_usage)
def _run_hierarchical_process(self) -> Union[str, Dict[str, Any]]: def _run_hierarchical_process(self) -> Union[str, Dict[str, Any]]:
"""Creates and assigns a manager agent to make sure the crew completes the tasks.""" """Creates and assigns a manager agent to make sure the crew completes the tasks."""
@@ -415,9 +437,11 @@ class Crew(BaseModel):
self._finish_execution(task_output) self._finish_execution(task_output)
# type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str") # type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
manager_token_usage = manager._token_process.get_summary() manager_token_usage = manager._token_process.get_summary()
return self._format_output( # TODO: TEST AND FIX
task_output, manager_token_usage return (
), manager_token_usage self._format_output(task_output, manager_token_usage),
manager_token_usage,
)
def copy(self): def copy(self):
"""Create a deep copy of the Crew.""" """Create a deep copy of the Crew."""
@@ -432,12 +456,18 @@ class Crew(BaseModel):
"_short_term_memory", "_short_term_memory",
"_long_term_memory", "_long_term_memory",
"_entity_memory", "_entity_memory",
"_telemetry",
"agents", "agents",
"tasks", "tasks",
} }
print("CREW ID", self.id)
print("CURRENT IDS FOR AGENTS", [agent.id for agent in self.agents])
print("CURRENT IDS FOR TASKS", [task.id for task in self.tasks])
# TODO: I think there is a disconnect. We need to pass new agents and tasks to the new crew
cloned_agents = [agent.copy() for agent in self.agents] cloned_agents = [agent.copy() for agent in self.agents]
cloned_tasks = [task.copy() for task in self.tasks] cloned_tasks = [task.copy(cloned_agents) for task in self.tasks]
copied_data = self.model_dump(exclude=exclude) copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None} copied_data = {k: v for k, v in copied_data.items() if v is not None}
@@ -447,6 +477,12 @@ class Crew(BaseModel):
copied_crew = Crew(**copied_data, agents=cloned_agents, tasks=cloned_tasks) copied_crew = Crew(**copied_data, agents=cloned_agents, tasks=cloned_tasks)
print("COPIED CREW ID", copied_crew.id)
print("NEW IDS FOR AGENTS", [agent.id for agent in copied_crew.agents])
print("NEW IDS FOR TASKS", [task.id for task in copied_crew.tasks])
# TODO: EXPERIMENT, PRINT ID'S AND MAKE SURE I'M CALLING THE RIGHT AGENTS AND TASKS
return copied_crew return copied_crew
def _set_tasks_callbacks(self) -> None: def _set_tasks_callbacks(self) -> None:
@@ -475,6 +511,7 @@ class Crew(BaseModel):
If full_output is True, then returned data type will be a dictionary else returned outputs are string If full_output is True, then returned data type will be a dictionary else returned outputs are string
""" """
if self.full_output: if self.full_output:
print("SPITTING OUT FULL OUTPUT FOR CREW", self.id)
return { # type: ignore # Incompatible return value type (got "dict[str, Sequence[str | TaskOutput | None]]", expected "str") return { # type: ignore # Incompatible return value type (got "dict[str, Sequence[str | TaskOutput | None]]", expected "str")
"final_output": output, "final_output": output,
"tasks_outputs": [task.output for task in self.tasks if task], "tasks_outputs": [task.output for task in self.tasks if task],

View File

@@ -1,8 +1,8 @@
from copy import deepcopy
import os import os
import re import re
import threading import threading
import uuid import uuid
from copy import deepcopy
from typing import Any, Dict, List, Optional, Type from typing import Any, Dict, List, Optional, Type
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
@@ -192,13 +192,19 @@ class Task(BaseModel):
) )
return result return result
def _execute(self, agent, task, context, tools): def _execute(self, agent: Agent, task, context, tools):
result = agent.execute_task( result = agent.execute_task(
task=task, task=task,
context=context, context=context,
tools=tools, tools=tools,
) )
print("CALLING EXECUTE ON TASK WITH ID", task.id)
print("THIS TASK IS CALLING AGENT", agent.id)
print(
"CALLING TOKEN PROCESS in task on AGENT", agent._token_process.get_summary()
)
exported_output = self._export_output(result) exported_output = self._export_output(result)
self.output = TaskOutput( self.output = TaskOutput(
@@ -246,7 +252,7 @@ class Task(BaseModel):
"""Increment the delegations counter.""" """Increment the delegations counter."""
self.delegations += 1 self.delegations += 1
def copy(self): def copy(self, agents: Optional[List[Agent]] = None) -> "Task":
"""Create a deep copy of the Task.""" """Create a deep copy of the Task."""
exclude = { exclude = {
"id", "id",
@@ -254,6 +260,7 @@ class Task(BaseModel):
"context", "context",
"tools", "tools",
} }
print("ORIGINAL TOOLS:", self.tools)
copied_data = self.model_dump(exclude=exclude) copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None} copied_data = {k: v for k, v in copied_data.items() if v is not None}
@@ -261,8 +268,16 @@ class Task(BaseModel):
cloned_context = ( cloned_context = (
[task.copy() for task in self.context] if self.context else None [task.copy() for task in self.context] if self.context else None
) )
cloned_agent = self.agent.copy() if self.agent else None
cloned_tools = deepcopy(self.tools) if self.tools else None # TODO: Make sure this clone approach is correct.
def get_agent_by_role(role: str) -> Agent | None:
return next((agent for agent in agents if agent.role == role), None)
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
# cloned_agent = self.agent.copy() if self.agent else None
cloned_tools = deepcopy(self.tools) if self.tools else []
print("CLONED_TOOLS", cloned_tools)
copied_task = Task( copied_task = Task(
**copied_data, **copied_data,
@@ -270,6 +285,8 @@ class Task(BaseModel):
agent=cloned_agent, agent=cloned_agent,
tools=cloned_tools, tools=cloned_tools,
) )
print("TASK TOOLS:", copied_task.tools)
return copied_task return copied_task
def _export_output(self, result: str) -> Any: def _export_output(self, result: str) -> Any:
@@ -328,7 +345,9 @@ class Task(BaseModel):
if self.output_file: if self.output_file:
content = ( content = (
# type: ignore # "str" has no attribute "json" # type: ignore # "str" has no attribute "json"
exported_result if not self.output_pydantic else exported_result.json() exported_result
if not self.output_pydantic
else exported_result.model_dump_json()
) )
self._save_file(content) self._save_file(content)

View File

@@ -1,3 +1,4 @@
import uuid
from typing import Any, Dict, List from typing import Any, Dict, List
import tiktoken import tiktoken
@@ -6,6 +7,7 @@ from langchain.schema import LLMResult
class TokenProcess: class TokenProcess:
id = uuid.uuid4() # TODO: REMOVE THIS
total_tokens: int = 0 total_tokens: int = 0
prompt_tokens: int = 0 prompt_tokens: int = 0
completion_tokens: int = 0 completion_tokens: int = 0
@@ -32,6 +34,7 @@ class TokenProcess:
class TokenCalcHandler(BaseCallbackHandler): class TokenCalcHandler(BaseCallbackHandler):
id = uuid.uuid4() # TODO: REMOVE THIS
model: str = "" model: str = ""
token_cost_process: TokenProcess token_cost_process: TokenProcess

File diff suppressed because it is too large Load Diff

View File

@@ -391,6 +391,83 @@ def test_crew_full_ouput():
"total_tokens": 439, "total_tokens": 439,
}, },
} }
assert False
"""
Issues:
- Each output is not tracking usage metrics
"""
# @pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_for_each_full_ouput():
# TODO: Add docstrings to all tests
from unittest.mock import patch
inputs = [
{"topic": "dog"},
# {"topic": "cat"},
# {"topic": "apple"},
]
expected_outputs = [
"Dogs are loyal companions and popular pets.",
"Cats are independent and low-maintenance pets.",
"Apples are a rich source of dietary fiber and vitamin C.",
]
agent = Agent(
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
)
task = Task(
description="Give me an analysis around {topic}.",
expected_output="1 bullet point about {topic} that's under 15 words.",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task], full_output=True)
results = crew.kickoff_for_each(inputs=inputs)
# with patch.object(Agent, "execute_task") as mock_execute_task:
# mock_execute_task.side_effect = expected_outputs
assert len(results) == len(inputs)
print("RESULTS:", results)
assert False
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_async_kickoff_for_each_full_ouput():
inputs = [
{"topic": "dog"},
{"topic": "cat"},
{"topic": "apple"},
]
expected_outputs = [
"Dogs are loyal companions and popular pets.",
"Cats are independent and low-maintenance pets.",
"Apples are a rich source of dietary fiber and vitamin C.",
]
agent = Agent(
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
)
task = Task(
description="Give me an analysis around {topic}.",
expected_output="1 bullet point about {topic} that's under 15 words.",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task1, task2], full_output=True)
assert False
def test_agents_rpm_is_never_set_if_crew_max_RPM_is_not_set(): def test_agents_rpm_is_never_set_if_crew_max_RPM_is_not_set():
@@ -465,6 +542,266 @@ def test_async_task_execution():
join.assert_called() join.assert_called()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_kickoff_for_each_single_input():
"""Tests if kickoff_for_each works with a single input."""
from unittest.mock import patch
inputs = [{"topic": "dog"}]
expected_outputs = ["Dogs are loyal companions and popular pets."]
agent = Agent(
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
)
task = Task(
description="Give me an analysis around {topic}.",
expected_output="1 bullet point about {topic} that's under 15 words.",
agent=agent,
)
with patch.object(Agent, "execute_task") as mock_execute_task:
mock_execute_task.side_effect = expected_outputs
crew = Crew(agents=[agent], tasks=[task])
results = crew.kickoff_for_each(inputs=inputs)
assert len(results) == 1
assert results == expected_outputs
@pytest.mark.vcr(filter_headers=["authorization"])
def test_kickoff_for_each_multiple_inputs():
"""Tests if kickoff_for_each works with multiple inputs."""
from unittest.mock import patch
inputs = [
{"topic": "dog"},
{"topic": "cat"},
{"topic": "apple"},
]
expected_outputs = [
"Dogs are loyal companions and popular pets.",
"Cats are independent and low-maintenance pets.",
"Apples are a rich source of dietary fiber and vitamin C.",
]
agent = Agent(
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
)
task = Task(
description="Give me an analysis around {topic}.",
expected_output="1 bullet point about {topic} that's under 15 words.",
agent=agent,
)
with patch.object(Agent, "execute_task") as mock_execute_task:
mock_execute_task.side_effect = expected_outputs
crew = Crew(agents=[agent], tasks=[task])
results = crew.kickoff_for_each(inputs=inputs)
assert len(results) == len(inputs)
for i, res in enumerate(results):
assert res == expected_outputs[i]
@pytest.mark.vcr(filter_headers=["authorization"])
def test_kickoff_for_each_empty_input():
"""Tests if kickoff_for_each handles an empty input list."""
agent = Agent(
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
)
task = Task(
description="Give me an analysis around {topic}.",
expected_output="1 bullet point about {topic} that's under 15 words.",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
results = crew.kickoff_for_each(inputs=[])
assert results == []
@pytest.mark.vcr(filter_headers=["authorization"])
def test_kickoff_for_each_invalid_input():
"""Tests if kickoff_for_each raises TypeError for invalid input types."""
agent = Agent(
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
)
task = Task(
description="Give me an analysis around {topic}.",
expected_output="1 bullet point about {topic} that's under 15 words.",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
with pytest.raises(TypeError):
# Pass a string instead of a list
crew.kickoff_for_each("invalid input")
@pytest.mark.vcr(filter_headers=["authorization"])
def test_kickoff_for_each_error_handling():
"""Tests error handling in kickoff_for_each when kickoff raises an error."""
from unittest.mock import patch
inputs = [
{"topic": "dog"},
{"topic": "cat"},
{"topic": "apple"},
]
expected_outputs = [
"Dogs are loyal companions and popular pets.",
"Cats are independent and low-maintenance pets.",
"Apples are a rich source of dietary fiber and vitamin C.",
]
agent = Agent(
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
)
task = Task(
description="Give me an analysis around {topic}.",
expected_output="1 bullet point about {topic} that's under 15 words.",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
with patch.object(Crew, "kickoff") as mock_kickoff:
mock_kickoff.side_effect = expected_outputs[:2] + [
Exception("Simulated kickoff error")
]
with pytest.raises(Exception, match="Simulated kickoff error"):
crew.kickoff_for_each(inputs=inputs)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.asyncio
async def test_kickoff_async_basic_functionality_and_output():
"""Tests the basic functionality and output of kickoff_async."""
from unittest.mock import patch
inputs = {"topic": "dog"}
agent = Agent(
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
)
task = Task(
description="Give me an analysis around {topic}.",
expected_output="1 bullet point about {topic} that's under 15 words.",
agent=agent,
)
# Create the crew
crew = Crew(
agents=[agent],
tasks=[task],
)
expected_output = "This is a sample output from kickoff."
with patch.object(Crew, "kickoff", return_value=expected_output) as mock_kickoff:
result = await crew.kickoff_async(inputs)
assert isinstance(result, str), "Result should be a string"
assert result == expected_output, "Result should match expected output"
mock_kickoff.assert_called_once_with(inputs)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.asyncio # Use pytest-asyncio for async tests
async def test_async_kickoff_for_each_async_basic_functionality_and_output():
"""Tests the basic functionality and output of akickoff_for_each_async."""
from unittest.mock import patch
inputs = [
{"topic": "dog"},
{"topic": "cat"},
{"topic": "apple"},
]
# Define expected outputs for each input
expected_outputs = [
"Dogs are loyal companions and popular pets.",
"Cats are independent and low-maintenance pets.",
"Apples are a rich source of dietary fiber and vitamin C.",
]
agent = Agent(
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
)
task = Task(
description="Give me an analysis around {topic}.",
expected_output="1 bullet point about {topic} that's under 15 words.",
agent=agent,
)
with patch.object(
Crew, "kickoff_async", side_effect=expected_outputs
) as mock_kickoff_async:
crew = Crew(agents=[agent], tasks=[task])
results = await crew.kickoff_for_each_async(inputs)
assert len(results) == len(inputs)
assert results == expected_outputs
for input_data in inputs:
mock_kickoff_async.assert_any_call(inputs=input_data)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.asyncio
async def test_async_kickoff_for_each_async_empty_input():
"""Tests if akickoff_for_each_async handles an empty input list."""
agent = Agent(
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
)
task = Task(
description="Give me an analysis around {topic}.",
expected_output="1 bullet point about {topic} that's under 15 words.",
agent=agent,
)
# Create the crew
crew = Crew(
agents=[agent],
tasks=[task],
)
# Call the function we are testing
results = await crew.kickoff_for_each_async([])
# Assertion
assert results == [], "Result should be an empty list when input is empty"
# TODO: TEST KICKOFF FOR EACH WITH USAGE METRICS
# TODO: TEST ASYNC KICKOFF FOR EACH WITH USAGE METRICS
def test_set_agents_step_callback(): def test_set_agents_step_callback():
from unittest.mock import patch from unittest.mock import patch