mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 15:52:34 +00:00
Address PR feedback: Improve error handling, add OpenAPI docs, and verify task management
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -8,9 +8,12 @@ import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
|
||||
from uuid import uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
|
||||
from crewai.types.a2a import (
|
||||
Artifact,
|
||||
Message,
|
||||
@@ -165,12 +168,37 @@ class TaskManager(ABC):
|
||||
class InMemoryTaskManager(TaskManager):
|
||||
"""In-memory implementation of the A2A task manager."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the in-memory task manager."""
|
||||
def __init__(
|
||||
self,
|
||||
task_ttl: Optional[int] = None,
|
||||
cleanup_interval: Optional[int] = None,
|
||||
config: Optional["A2AConfig"] = None,
|
||||
):
|
||||
"""Initialize the in-memory task manager.
|
||||
|
||||
Args:
|
||||
task_ttl: Time to live for tasks in seconds. Default is 1 hour.
|
||||
cleanup_interval: Interval for cleaning up expired tasks in seconds. Default is 5 minutes.
|
||||
config: The A2A configuration. If provided, other parameters are ignored.
|
||||
"""
|
||||
from crewai.a2a.config import A2AConfig
|
||||
self.config = config or A2AConfig.from_env()
|
||||
|
||||
self._task_ttl = task_ttl if task_ttl is not None else self.config.task_ttl
|
||||
self._cleanup_interval = cleanup_interval if cleanup_interval is not None else self.config.cleanup_interval
|
||||
|
||||
self._tasks: Dict[str, Task] = {}
|
||||
self._push_notifications: Dict[str, PushNotificationConfig] = {}
|
||||
self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {}
|
||||
self._task_timestamps: Dict[str, datetime] = {}
|
||||
self._logger = logging.getLogger(__name__)
|
||||
self._cleanup_task = None
|
||||
|
||||
try:
|
||||
if asyncio.get_running_loop():
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
except RuntimeError:
|
||||
self._logger.info("No running event loop, periodic cleanup disabled")
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
@@ -198,6 +226,7 @@ class InMemoryTaskManager(TaskManager):
|
||||
state=TaskState.SUBMITTED,
|
||||
message=message,
|
||||
timestamp=datetime.now(),
|
||||
previous_state=None, # Initial state has no previous state
|
||||
)
|
||||
|
||||
task = Task(
|
||||
@@ -211,6 +240,7 @@ class InMemoryTaskManager(TaskManager):
|
||||
|
||||
self._tasks[task_id] = task
|
||||
self._task_subscribers[task_id] = set()
|
||||
self._task_timestamps[task_id] = datetime.now()
|
||||
return task
|
||||
|
||||
async def get_task(
|
||||
@@ -263,20 +293,29 @@ class InMemoryTaskManager(TaskManager):
|
||||
raise KeyError(f"Task {task_id} not found")
|
||||
|
||||
task = self._tasks[task_id]
|
||||
task = self._tasks[task_id]
|
||||
previous_state = task.status.state if task.status else None
|
||||
|
||||
if previous_state and not TaskState.is_valid_transition(previous_state, state):
|
||||
raise ValueError(f"Invalid state transition from {previous_state} to {state}")
|
||||
|
||||
status = TaskStatus(
|
||||
state=state,
|
||||
message=message,
|
||||
timestamp=datetime.now(),
|
||||
previous_state=previous_state,
|
||||
)
|
||||
task.status = status
|
||||
|
||||
if message and task.history is not None:
|
||||
task.history.append(message)
|
||||
|
||||
self._task_timestamps[task_id] = datetime.now()
|
||||
|
||||
event = TaskStatusUpdateEvent(
|
||||
id=task_id,
|
||||
status=status,
|
||||
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED],
|
||||
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED, TaskState.EXPIRED],
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
@@ -436,3 +475,48 @@ class InMemoryTaskManager(TaskManager):
|
||||
if task_id in self._task_subscribers:
|
||||
for queue in self._task_subscribers[task_id]:
|
||||
await queue.put(event)
|
||||
|
||||
async def _periodic_cleanup(self) -> None:
|
||||
"""Periodically clean up expired tasks."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self._cleanup_interval)
|
||||
await self._cleanup_expired_tasks()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.exception(f"Error during periodic cleanup: {e}")
|
||||
|
||||
async def _cleanup_expired_tasks(self) -> None:
|
||||
"""Clean up expired tasks."""
|
||||
now = datetime.now()
|
||||
expired_tasks = []
|
||||
|
||||
for task_id, timestamp in self._task_timestamps.items():
|
||||
if (now - timestamp).total_seconds() > self._task_ttl:
|
||||
expired_tasks.append(task_id)
|
||||
|
||||
for task_id in expired_tasks:
|
||||
self._logger.info(f"Cleaning up expired task: {task_id}")
|
||||
self._tasks.pop(task_id, None)
|
||||
self._push_notifications.pop(task_id, None)
|
||||
self._task_timestamps.pop(task_id, None)
|
||||
|
||||
if task_id in self._task_subscribers:
|
||||
previous_state = None
|
||||
if task_id in self._tasks and self._tasks[task_id].status:
|
||||
previous_state = self._tasks[task_id].status.state
|
||||
|
||||
status = TaskStatus(
|
||||
state=TaskState.EXPIRED,
|
||||
timestamp=now,
|
||||
previous_state=previous_state,
|
||||
)
|
||||
event = TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
status=status,
|
||||
final=True,
|
||||
)
|
||||
await self._notify_subscribers(task_id, event)
|
||||
|
||||
self._task_subscribers.pop(task_id, None)
|
||||
|
||||
Reference in New Issue
Block a user