mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-21 13:58:15 +00:00
WIP: working replay feat fixing inputs, need tests
This commit is contained in:
@@ -81,6 +81,7 @@ class Crew(BaseModel):
|
||||
_entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr()
|
||||
_train: Optional[bool] = PrivateAttr(default=False)
|
||||
_train_iteration: Optional[int] = PrivateAttr()
|
||||
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
|
||||
|
||||
cache: bool = Field(default=True)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
@@ -333,7 +334,9 @@ class Crew(BaseModel):
|
||||
"""Starts the crew to work on its assigned tasks."""
|
||||
self._execution_span = self._telemetry.crew_execution_span(self, inputs)
|
||||
self.execution_logs = []
|
||||
self._task_output_handler.reset()
|
||||
if inputs is not None:
|
||||
self._inputs = inputs
|
||||
self._interpolate_inputs(inputs)
|
||||
# self._interpolate_inputs(inputs)
|
||||
self._set_tasks_callbacks()
|
||||
@@ -359,7 +362,7 @@ class Crew(BaseModel):
|
||||
metrics = []
|
||||
|
||||
if self.process == Process.sequential:
|
||||
result = self._run_sequential_process(inputs)
|
||||
result = self._run_sequential_process()
|
||||
elif self.process == Process.hierarchical:
|
||||
result, manager_metrics = self._run_hierarchical_process() # type: ignore # Incompatible types in assignment (expression has type "str | dict[str, Any]", variable has type "str")
|
||||
metrics.append(manager_metrics)
|
||||
@@ -452,7 +455,11 @@ class Crew(BaseModel):
|
||||
|
||||
return results
|
||||
|
||||
def _store_execution_log(self, task, output, task_index, inputs=None):
|
||||
def _store_execution_log(self, task: Task, output, task_index):
|
||||
if self._inputs:
|
||||
inputs = self._inputs
|
||||
else:
|
||||
inputs = {}
|
||||
log = {
|
||||
"task_id": str(task.id),
|
||||
"description": task.description,
|
||||
@@ -468,19 +475,15 @@ class Crew(BaseModel):
|
||||
},
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"task_index": task_index,
|
||||
# "output_py": output.pydantic_output,
|
||||
"inputs": inputs,
|
||||
# "task": task.model_dump(),
|
||||
}
|
||||
self.execution_logs.append(log)
|
||||
self._task_output_handler.append(log)
|
||||
|
||||
def _run_sequential_process(
|
||||
self, inputs: Dict[str, Any] | None = None
|
||||
) -> CrewOutput:
|
||||
def _run_sequential_process(self) -> CrewOutput:
|
||||
"""Executes tasks sequentially and returns the final output."""
|
||||
self.execution_logs = []
|
||||
task_outputs = self._execute_tasks(self.tasks, inputs=inputs)
|
||||
task_outputs = self._execute_tasks(self.tasks)
|
||||
final_string_output = aggregate_raw_outputs_from_task_outputs(task_outputs)
|
||||
self._finish_execution(final_string_output)
|
||||
self.save_execution_logs()
|
||||
@@ -493,10 +496,9 @@ class Crew(BaseModel):
|
||||
tasks,
|
||||
start_index=0,
|
||||
is_replay=False,
|
||||
inputs: Dict[str, Any] | None = None,
|
||||
):
|
||||
task_outputs: List[TaskOutput] = []
|
||||
futures: List[Tuple[Task, Future[TaskOutput]]] = []
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
|
||||
for task_index, task in enumerate(tasks[start_index:], start=start_index):
|
||||
if task.agent and task.agent.allow_delegation:
|
||||
agents_for_delegation = [
|
||||
@@ -527,12 +529,10 @@ class Crew(BaseModel):
|
||||
future = task.execute_async(
|
||||
agent=task.agent, context=context, tools=task.tools
|
||||
)
|
||||
futures.append((task, future))
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(
|
||||
futures, task_index, inputs
|
||||
)
|
||||
task_outputs = self._process_async_tasks(futures)
|
||||
futures.clear()
|
||||
|
||||
context = aggregate_raw_outputs_from_task_outputs(task_outputs)
|
||||
@@ -541,10 +541,10 @@ class Crew(BaseModel):
|
||||
)
|
||||
task_outputs = [task_output]
|
||||
self._process_task_result(task, task_output)
|
||||
self._store_execution_log(task, task_output, task_index, inputs)
|
||||
self._store_execution_log(task, task_output, task_index)
|
||||
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(futures, len(tasks), inputs)
|
||||
task_outputs = self._process_async_tasks(futures)
|
||||
|
||||
return task_outputs
|
||||
|
||||
@@ -556,20 +556,35 @@ class Crew(BaseModel):
|
||||
|
||||
def _process_async_tasks(
|
||||
self,
|
||||
futures: List[Tuple[Task, Future[TaskOutput]]],
|
||||
task_index: int,
|
||||
inputs: Dict[str, Any] | None = None,
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]],
|
||||
) -> List[TaskOutput]:
|
||||
task_outputs = []
|
||||
for future_task, future in futures:
|
||||
for future_task, future, task_index in futures:
|
||||
task_output = future.result()
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(future_task, task_output)
|
||||
self._store_execution_log(future_task, task_output, task_index, inputs)
|
||||
self._store_execution_log(future_task, task_output, task_index)
|
||||
|
||||
return task_outputs
|
||||
|
||||
def replay_from_task(self, task_id: str):
|
||||
def _get_agent(self, role: str) -> Optional[BaseAgent]:
|
||||
"""
|
||||
Private method to get an agent by role.
|
||||
|
||||
Args:
|
||||
role (str): The role of the agent to retrieve.
|
||||
|
||||
Returns:
|
||||
Optional[BaseAgent]: The agent with the specified role, or None if not found.
|
||||
"""
|
||||
for agent in self.agents:
|
||||
if agent.role == role:
|
||||
return agent
|
||||
return None
|
||||
|
||||
def replay_from_task(self, task_id: str, inputs: Dict[str, Any] | None = None):
|
||||
all_tasks = self.tasks.copy()
|
||||
|
||||
stored_outputs = self._load_stored_outputs()
|
||||
start_index = next(
|
||||
(
|
||||
@@ -579,9 +594,20 @@ class Crew(BaseModel):
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Generate tasks based on what was previously replayed
|
||||
if len(self.tasks) != len(stored_outputs):
|
||||
for output in stored_outputs[start_index:]:
|
||||
matching_index = output["task_index"]
|
||||
matching_task = self.tasks[matching_index]
|
||||
if matching_task:
|
||||
new_task = matching_task.copy()
|
||||
new_task.agent = self._get_agent(output["agent_role"])
|
||||
all_tasks.append(new_task)
|
||||
|
||||
if start_index is None:
|
||||
raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")
|
||||
# Create a map of task ID to stored output
|
||||
|
||||
# this handles passing the correct context along and updating following task executions with the new task_ouputs as context
|
||||
stored_output_map: Dict[str, dict] = {
|
||||
log["task_id"]: log["output"] for log in stored_outputs
|
||||
}
|
||||
@@ -589,13 +615,21 @@ class Crew(BaseModel):
|
||||
task_outputs: List[
|
||||
TaskOutput
|
||||
] = [] # will propogate the old outputs first to add context then fill the content with the new task outputs relative to the replay start
|
||||
futures: List[Tuple[Task, Future[TaskOutput]]] = []
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
|
||||
context = ""
|
||||
|
||||
inputs = stored_outputs[start_index].get("inputs", {})
|
||||
if inputs is not None:
|
||||
self._interpolate_inputs(inputs)
|
||||
for task_index, task in enumerate(self.tasks):
|
||||
# inputs can be overrided with new passed inputs
|
||||
replay_inputs = (
|
||||
inputs
|
||||
if inputs is not None
|
||||
else stored_outputs[start_index].get("inputs", {})
|
||||
)
|
||||
|
||||
self._inputs = replay_inputs # overriding
|
||||
if replay_inputs:
|
||||
self._interpolate_inputs(replay_inputs)
|
||||
|
||||
for task_index, task in enumerate(all_tasks):
|
||||
if task_index < start_index:
|
||||
# Use stored output for tasks before the replay point
|
||||
if task.id in stored_output_map:
|
||||
@@ -632,12 +666,10 @@ class Crew(BaseModel):
|
||||
future = task.execute_async(
|
||||
agent=task.agent, context=context, tools=task.tools
|
||||
)
|
||||
futures.append((task, future))
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
if futures:
|
||||
async_outputs = self._process_async_tasks(
|
||||
futures, task_index, inputs
|
||||
)
|
||||
async_outputs = self._process_async_tasks(futures)
|
||||
task_outputs.extend(async_outputs)
|
||||
for output in async_outputs:
|
||||
context += (
|
||||
@@ -649,14 +681,14 @@ class Crew(BaseModel):
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
self._store_execution_log(task, task_output, task_index, inputs)
|
||||
self._store_execution_log(task, task_output, task_index)
|
||||
context += (
|
||||
f"\nTask {task_index + 1} Output:\n{task_output.raw_output}"
|
||||
)
|
||||
|
||||
# Process any remaining async tasks
|
||||
if futures:
|
||||
async_outputs = self._process_async_tasks(futures, len(self.tasks), inputs)
|
||||
async_outputs = self._process_async_tasks(futures)
|
||||
task_outputs.extend(async_outputs)
|
||||
# Calculate usage metrics
|
||||
token_usage = self.calculate_usage_metrics()
|
||||
|
||||
@@ -306,6 +306,8 @@ class Task(BaseModel):
|
||||
)
|
||||
|
||||
def get_agent_by_role(role: str) -> Union["BaseAgent", None]:
|
||||
if agents is None:
|
||||
return None
|
||||
return next((agent for agent in agents if agent.role == role), None)
|
||||
|
||||
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
|
||||
|
||||
@@ -6,12 +6,26 @@ from pydantic import BaseModel
|
||||
|
||||
class CrewJSONEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
if isinstance(obj, UUID):
|
||||
return str(obj)
|
||||
if isinstance(obj, BaseModel):
|
||||
return obj.model_dump()
|
||||
if hasattr(obj, "__dict__"):
|
||||
return obj.__dict__
|
||||
return str(obj)
|
||||
return self._handle_pydantic_model(obj)
|
||||
elif isinstance(obj, UUID):
|
||||
return str(obj)
|
||||
|
||||
elif isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
return super().default(obj)
|
||||
|
||||
def _handle_pydantic_model(self, obj):
|
||||
try:
|
||||
data = obj.model_dump()
|
||||
# Remove circular references
|
||||
for key, value in data.items():
|
||||
if isinstance(value, BaseModel):
|
||||
data[key] = str(
|
||||
value
|
||||
) # Convert nested models to string representation
|
||||
return data
|
||||
except RecursionError:
|
||||
return str(
|
||||
obj
|
||||
) # Fall back to string representation if circular reference is detected
|
||||
|
||||
@@ -97,6 +97,11 @@ class TaskOutputJsonHandler:
|
||||
json.dump(file_data, file, indent=2, cls=CrewJSONEncoder)
|
||||
file.truncate()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the JSON file by creating an empty file."""
|
||||
with open(self.file_path, "w") as f:
|
||||
json.dump([], f)
|
||||
|
||||
def load(self) -> list:
|
||||
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user