Files
crewAI/src/crewai/a2a/task_manager.py

523 lines
15 KiB
Python

"""
A2A protocol task manager for CrewAI.
This module implements the task manager for the A2A protocol in CrewAI.
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
from uuid import uuid4
if TYPE_CHECKING:
from crewai.a2a.config import A2AConfig
from crewai.types.a2a import (
Artifact,
Message,
PushNotificationConfig,
Task,
TaskArtifactUpdateEvent,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
class TaskManager(ABC):
"""Abstract base class for A2A task managers."""
@abstractmethod
async def create_task(
self,
task_id: str,
session_id: Optional[str] = None,
message: Optional[Message] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Task:
"""Create a new task.
Args:
task_id: The ID of the task.
session_id: The session ID.
message: The initial message.
metadata: Additional metadata.
Returns:
The created task.
"""
pass
@abstractmethod
async def get_task(
self, task_id: str, history_length: Optional[int] = None
) -> Task:
"""Get a task by ID.
Args:
task_id: The ID of the task.
history_length: The number of messages to include in the history.
Returns:
The task.
Raises:
KeyError: If the task is not found.
"""
pass
@abstractmethod
async def update_task_status(
self,
task_id: str,
state: TaskState,
message: Optional[Message] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> TaskStatusUpdateEvent:
"""Update the status of a task.
Args:
task_id: The ID of the task.
state: The new state of the task.
message: An optional message to include with the status update.
metadata: Additional metadata.
Returns:
The task status update event.
Raises:
KeyError: If the task is not found.
"""
pass
@abstractmethod
async def add_task_artifact(
self,
task_id: str,
artifact: Artifact,
metadata: Optional[Dict[str, Any]] = None,
) -> TaskArtifactUpdateEvent:
"""Add an artifact to a task.
Args:
task_id: The ID of the task.
artifact: The artifact to add.
metadata: Additional metadata.
Returns:
The task artifact update event.
Raises:
KeyError: If the task is not found.
"""
pass
@abstractmethod
async def cancel_task(self, task_id: str) -> Task:
"""Cancel a task.
Args:
task_id: The ID of the task.
Returns:
The canceled task.
Raises:
KeyError: If the task is not found.
"""
pass
@abstractmethod
async def set_push_notification(
self, task_id: str, config: PushNotificationConfig
) -> PushNotificationConfig:
"""Set push notification for a task.
Args:
task_id: The ID of the task.
config: The push notification configuration.
Returns:
The push notification configuration.
Raises:
KeyError: If the task is not found.
"""
pass
@abstractmethod
async def get_push_notification(
self, task_id: str
) -> Optional[PushNotificationConfig]:
"""Get push notification for a task.
Args:
task_id: The ID of the task.
Returns:
The push notification configuration, or None if not set.
Raises:
KeyError: If the task is not found.
"""
pass
class InMemoryTaskManager(TaskManager):
"""In-memory implementation of the A2A task manager."""
def __init__(
self,
task_ttl: Optional[int] = None,
cleanup_interval: Optional[int] = None,
config: Optional["A2AConfig"] = None,
):
"""Initialize the in-memory task manager.
Args:
task_ttl: Time to live for tasks in seconds. Default is 1 hour.
cleanup_interval: Interval for cleaning up expired tasks in seconds. Default is 5 minutes.
config: The A2A configuration. If provided, other parameters are ignored.
"""
from crewai.a2a.config import A2AConfig
self.config = config or A2AConfig.from_env()
self._task_ttl = task_ttl if task_ttl is not None else self.config.task_ttl
self._cleanup_interval = cleanup_interval if cleanup_interval is not None else self.config.cleanup_interval
self._tasks: Dict[str, Task] = {}
self._push_notifications: Dict[str, PushNotificationConfig] = {}
self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {}
self._task_timestamps: Dict[str, datetime] = {}
self._logger = logging.getLogger(__name__)
self._cleanup_task = None
try:
if asyncio.get_running_loop():
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
except RuntimeError:
self._logger.info("No running event loop, periodic cleanup disabled")
async def create_task(
self,
task_id: str,
session_id: Optional[str] = None,
message: Optional[Message] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Task:
"""Create a new task.
Args:
task_id: The ID of the task.
session_id: The session ID.
message: The initial message.
metadata: Additional metadata.
Returns:
The created task.
"""
if task_id in self._tasks:
return self._tasks[task_id]
session_id = session_id or uuid4().hex
status = TaskStatus(
state=TaskState.SUBMITTED,
message=message,
timestamp=datetime.now(),
previous_state=None, # Initial state has no previous state
)
task = Task(
id=task_id,
sessionId=session_id,
status=status,
artifacts=[],
history=[message] if message else [],
metadata=metadata or {},
)
self._tasks[task_id] = task
self._task_subscribers[task_id] = set()
self._task_timestamps[task_id] = datetime.now()
return task
async def get_task(
self, task_id: str, history_length: Optional[int] = None
) -> Task:
"""Get a task by ID.
Args:
task_id: The ID of the task.
history_length: The number of messages to include in the history.
Returns:
The task.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
task = self._tasks[task_id]
if history_length is not None and task.history:
task_copy = task.model_copy(deep=True)
task_copy.history = task.history[-history_length:]
return task_copy
return task
async def update_task_status(
self,
task_id: str,
state: TaskState,
message: Optional[Message] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> TaskStatusUpdateEvent:
"""Update the status of a task.
Args:
task_id: The ID of the task.
state: The new state of the task.
message: An optional message to include with the status update.
metadata: Additional metadata.
Returns:
The task status update event.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
task = self._tasks[task_id]
task = self._tasks[task_id]
previous_state = task.status.state if task.status else None
if previous_state and not TaskState.is_valid_transition(previous_state, state):
raise ValueError(f"Invalid state transition from {previous_state} to {state}")
status = TaskStatus(
state=state,
message=message,
timestamp=datetime.now(),
previous_state=previous_state,
)
task.status = status
if message and task.history is not None:
task.history.append(message)
self._task_timestamps[task_id] = datetime.now()
event = TaskStatusUpdateEvent(
id=task_id,
status=status,
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED, TaskState.EXPIRED],
metadata=metadata or {},
)
await self._notify_subscribers(task_id, event)
return event
async def add_task_artifact(
self,
task_id: str,
artifact: Artifact,
metadata: Optional[Dict[str, Any]] = None,
) -> TaskArtifactUpdateEvent:
"""Add an artifact to a task.
Args:
task_id: The ID of the task.
artifact: The artifact to add.
metadata: Additional metadata.
Returns:
The task artifact update event.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
task = self._tasks[task_id]
if task.artifacts is None:
task.artifacts = []
if artifact.append and task.artifacts:
for existing in task.artifacts:
if existing.name == artifact.name:
existing.parts.extend(artifact.parts)
existing.lastChunk = artifact.lastChunk
break
else:
task.artifacts.append(artifact)
else:
task.artifacts.append(artifact)
event = TaskArtifactUpdateEvent(
id=task_id,
artifact=artifact,
metadata=metadata or {},
)
await self._notify_subscribers(task_id, event)
return event
async def cancel_task(self, task_id: str) -> Task:
"""Cancel a task.
Args:
task_id: The ID of the task.
Returns:
The canceled task.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
task = self._tasks[task_id]
if task.status.state not in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED]:
await self.update_task_status(task_id, TaskState.CANCELED)
return task
async def set_push_notification(
self, task_id: str, config: PushNotificationConfig
) -> PushNotificationConfig:
"""Set push notification for a task.
Args:
task_id: The ID of the task.
config: The push notification configuration.
Returns:
The push notification configuration.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
self._push_notifications[task_id] = config
return config
async def get_push_notification(
self, task_id: str
) -> Optional[PushNotificationConfig]:
"""Get push notification for a task.
Args:
task_id: The ID of the task.
Returns:
The push notification configuration, or None if not set.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
return self._push_notifications.get(task_id)
async def subscribe_to_task(self, task_id: str) -> asyncio.Queue:
"""Subscribe to task updates.
Args:
task_id: The ID of the task.
Returns:
A queue that will receive task updates.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
queue: asyncio.Queue = asyncio.Queue()
self._task_subscribers.setdefault(task_id, set()).add(queue)
return queue
async def unsubscribe_from_task(self, task_id: str, queue: asyncio.Queue) -> None:
"""Unsubscribe from task updates.
Args:
task_id: The ID of the task.
queue: The queue to unsubscribe.
"""
if task_id in self._task_subscribers:
self._task_subscribers[task_id].discard(queue)
async def _notify_subscribers(
self,
task_id: str,
event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent],
) -> None:
"""Notify subscribers of a task update.
Args:
task_id: The ID of the task.
event: The event to send to subscribers.
"""
if task_id in self._task_subscribers:
for queue in self._task_subscribers[task_id]:
await queue.put(event)
async def _periodic_cleanup(self) -> None:
"""Periodically clean up expired tasks."""
while True:
try:
await asyncio.sleep(self._cleanup_interval)
await self._cleanup_expired_tasks()
except asyncio.CancelledError:
break
except Exception as e:
self._logger.exception(f"Error during periodic cleanup: {e}")
async def _cleanup_expired_tasks(self) -> None:
"""Clean up expired tasks."""
now = datetime.now()
expired_tasks = []
for task_id, timestamp in self._task_timestamps.items():
if (now - timestamp).total_seconds() > self._task_ttl:
expired_tasks.append(task_id)
for task_id in expired_tasks:
self._logger.info(f"Cleaning up expired task: {task_id}")
self._tasks.pop(task_id, None)
self._push_notifications.pop(task_id, None)
self._task_timestamps.pop(task_id, None)
if task_id in self._task_subscribers:
previous_state = None
if task_id in self._tasks and self._tasks[task_id].status:
previous_state = self._tasks[task_id].status.state
status = TaskStatus(
state=TaskState.EXPIRED,
timestamp=now,
previous_state=previous_state,
)
event = TaskStatusUpdateEvent(
task_id=task_id,
status=status,
final=True,
)
await self._notify_subscribers(task_id, event)
self._task_subscribers.pop(task_id, None)