Address PR feedback: Improve error handling, add OpenAPI docs, and verify task management

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-05-09 04:58:19 +00:00
parent cfabb9fa78
commit 9bb8854c25
8 changed files with 482 additions and 57 deletions

View File

@@ -8,9 +8,12 @@ import asyncio
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
from uuid import uuid4
if TYPE_CHECKING:
from crewai.a2a.config import A2AConfig
from crewai.types.a2a import (
Artifact,
Message,
@@ -165,12 +168,37 @@ class TaskManager(ABC):
class InMemoryTaskManager(TaskManager):
"""In-memory implementation of the A2A task manager."""
def __init__(self):
"""Initialize the in-memory task manager."""
def __init__(
self,
task_ttl: Optional[int] = None,
cleanup_interval: Optional[int] = None,
config: Optional["A2AConfig"] = None,
):
"""Initialize the in-memory task manager.
Args:
task_ttl: Time to live for tasks in seconds. Default is 1 hour.
cleanup_interval: Interval for cleaning up expired tasks in seconds. Default is 5 minutes.
config: The A2A configuration. If provided, other parameters are ignored.
"""
from crewai.a2a.config import A2AConfig
self.config = config or A2AConfig.from_env()
self._task_ttl = task_ttl if task_ttl is not None else self.config.task_ttl
self._cleanup_interval = cleanup_interval if cleanup_interval is not None else self.config.cleanup_interval
self._tasks: Dict[str, Task] = {}
self._push_notifications: Dict[str, PushNotificationConfig] = {}
self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {}
self._task_timestamps: Dict[str, datetime] = {}
self._logger = logging.getLogger(__name__)
self._cleanup_task = None
try:
if asyncio.get_running_loop():
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
except RuntimeError:
self._logger.info("No running event loop, periodic cleanup disabled")
async def create_task(
self,
@@ -198,6 +226,7 @@ class InMemoryTaskManager(TaskManager):
state=TaskState.SUBMITTED,
message=message,
timestamp=datetime.now(),
previous_state=None, # Initial state has no previous state
)
task = Task(
@@ -211,6 +240,7 @@ class InMemoryTaskManager(TaskManager):
self._tasks[task_id] = task
self._task_subscribers[task_id] = set()
self._task_timestamps[task_id] = datetime.now()
return task
async def get_task(
@@ -263,20 +293,29 @@ class InMemoryTaskManager(TaskManager):
raise KeyError(f"Task {task_id} not found")
task = self._tasks[task_id]
task = self._tasks[task_id]
previous_state = task.status.state if task.status else None
if previous_state and not TaskState.is_valid_transition(previous_state, state):
raise ValueError(f"Invalid state transition from {previous_state} to {state}")
status = TaskStatus(
state=state,
message=message,
timestamp=datetime.now(),
previous_state=previous_state,
)
task.status = status
if message and task.history is not None:
task.history.append(message)
self._task_timestamps[task_id] = datetime.now()
event = TaskStatusUpdateEvent(
id=task_id,
status=status,
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED],
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED, TaskState.EXPIRED],
metadata=metadata or {},
)
@@ -436,3 +475,48 @@ class InMemoryTaskManager(TaskManager):
if task_id in self._task_subscribers:
for queue in self._task_subscribers[task_id]:
await queue.put(event)
async def _periodic_cleanup(self) -> None:
"""Periodically clean up expired tasks."""
while True:
try:
await asyncio.sleep(self._cleanup_interval)
await self._cleanup_expired_tasks()
except asyncio.CancelledError:
break
except Exception as e:
self._logger.exception(f"Error during periodic cleanup: {e}")
async def _cleanup_expired_tasks(self) -> None:
"""Clean up expired tasks."""
now = datetime.now()
expired_tasks = []
for task_id, timestamp in self._task_timestamps.items():
if (now - timestamp).total_seconds() > self._task_ttl:
expired_tasks.append(task_id)
for task_id in expired_tasks:
self._logger.info(f"Cleaning up expired task: {task_id}")
self._tasks.pop(task_id, None)
self._push_notifications.pop(task_id, None)
self._task_timestamps.pop(task_id, None)
if task_id in self._task_subscribers:
previous_state = None
if task_id in self._tasks and self._tasks[task_id].status:
previous_state = self._tasks[task_id].status.state
status = TaskStatus(
state=TaskState.EXPIRED,
timestamp=now,
previous_state=previous_state,
)
event = TaskStatusUpdateEvent(
task_id=task_id,
status=status,
final=True,
)
await self._notify_subscribers(task_id, event)
self._task_subscribers.pop(task_id, None)