mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
523 lines
15 KiB
Python
523 lines
15 KiB
Python
"""
|
|
A2A protocol task manager for CrewAI.
|
|
|
|
This module implements the task manager for the A2A protocol in CrewAI.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
|
|
from uuid import uuid4
|
|
|
|
if TYPE_CHECKING:
|
|
from crewai.a2a.config import A2AConfig
|
|
|
|
from crewai.types.a2a import (
|
|
Artifact,
|
|
Message,
|
|
PushNotificationConfig,
|
|
Task,
|
|
TaskArtifactUpdateEvent,
|
|
TaskState,
|
|
TaskStatus,
|
|
TaskStatusUpdateEvent,
|
|
)
|
|
|
|
|
|
class TaskManager(ABC):
    """Interface that every A2A task manager has to implement.

    Each operation is declared as an abstract coroutine; concrete subclasses
    decide where and how task state is stored.
    """

    @abstractmethod
    async def create_task(
        self,
        task_id: str,
        session_id: Optional[str] = None,
        message: Optional[Message] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Task:
        """Register a new task.

        Args:
            task_id: Identifier of the task to create.
            session_id: Identifier of the session the task belongs to.
            message: Message the task starts with, if any.
            metadata: Extra metadata to attach.

        Returns:
            The task that was created.
        """

    @abstractmethod
    async def get_task(
        self, task_id: str, history_length: Optional[int] = None
    ) -> Task:
        """Look a task up by its identifier.

        Args:
            task_id: Identifier of the task to fetch.
            history_length: How many history messages to include.

        Returns:
            The requested task.

        Raises:
            KeyError: No task with this identifier exists.
        """

    @abstractmethod
    async def update_task_status(
        self,
        task_id: str,
        state: TaskState,
        message: Optional[Message] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> TaskStatusUpdateEvent:
        """Move a task to a new state.

        Args:
            task_id: Identifier of the task to update.
            state: State the task should move to.
            message: Optional message accompanying the status change.
            metadata: Extra metadata to attach.

        Returns:
            The status update event describing the change.

        Raises:
            KeyError: No task with this identifier exists.
        """

    @abstractmethod
    async def add_task_artifact(
        self,
        task_id: str,
        artifact: Artifact,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> TaskArtifactUpdateEvent:
        """Attach an artifact to a task.

        Args:
            task_id: Identifier of the task receiving the artifact.
            artifact: Artifact to attach.
            metadata: Extra metadata to attach.

        Returns:
            The artifact update event describing the change.

        Raises:
            KeyError: No task with this identifier exists.
        """

    @abstractmethod
    async def cancel_task(self, task_id: str) -> Task:
        """Cancel a task.

        Args:
            task_id: Identifier of the task to cancel.

        Returns:
            The task that was canceled.

        Raises:
            KeyError: No task with this identifier exists.
        """

    @abstractmethod
    async def set_push_notification(
        self, task_id: str, config: PushNotificationConfig
    ) -> PushNotificationConfig:
        """Store the push notification configuration of a task.

        Args:
            task_id: Identifier of the task being configured.
            config: Push notification configuration to store.

        Returns:
            The push notification configuration.

        Raises:
            KeyError: No task with this identifier exists.
        """

    @abstractmethod
    async def get_push_notification(
        self, task_id: str
    ) -> Optional[PushNotificationConfig]:
        """Fetch the push notification configuration of a task.

        Args:
            task_id: Identifier of the task being queried.

        Returns:
            The stored push notification configuration, or None when the task
            has none.

        Raises:
            KeyError: No task with this identifier exists.
        """
|
|
|
|
|
|
class InMemoryTaskManager(TaskManager):
    """In-memory implementation of the A2A task manager.

    Tasks, push notification configurations, subscriber queues and last-update
    timestamps are kept in plain dictionaries keyed by task ID. When an event
    loop is already running at construction time, a background asyncio task
    periodically removes tasks whose last recorded update is older than the
    configured time to live.
    """

    def __init__(
        self,
        task_ttl: Optional[int] = None,
        cleanup_interval: Optional[int] = None,
        config: Optional["A2AConfig"] = None,
    ):
        """Initialize the in-memory task manager.

        Args:
            task_ttl: Time to live for tasks in seconds. When None, the value
                of ``config.task_ttl`` is used.
            cleanup_interval: Interval between cleanup passes in seconds. When
                None, the value of ``config.cleanup_interval`` is used.
            config: The A2A configuration. When omitted it is loaded with
                ``A2AConfig.from_env()``. Explicitly passed ``task_ttl`` and
                ``cleanup_interval`` take precedence over the configuration.
        """
        # Runtime import: the module-level import is guarded by TYPE_CHECKING.
        from crewai.a2a.config import A2AConfig

        self.config = config or A2AConfig.from_env()

        self._task_ttl = task_ttl if task_ttl is not None else self.config.task_ttl
        self._cleanup_interval = (
            cleanup_interval
            if cleanup_interval is not None
            else self.config.cleanup_interval
        )

        self._tasks: Dict[str, Task] = {}
        self._push_notifications: Dict[str, PushNotificationConfig] = {}
        self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {}
        self._task_timestamps: Dict[str, datetime] = {}
        self._logger = logging.getLogger(__name__)
        self._cleanup_task: Optional[asyncio.Task] = None

        # get_running_loop() raises RuntimeError when no loop is running, so
        # its return value does not need a separate truth test.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            self._logger.info("No running event loop, periodic cleanup disabled")
        else:
            self._cleanup_task = asyncio.create_task(self._periodic_cleanup())

    def _require_task(self, task_id: str) -> Task:
        """Return the stored task or raise KeyError with the standard message."""
        try:
            return self._tasks[task_id]
        except KeyError:
            raise KeyError(f"Task {task_id} not found") from None

    @staticmethod
    def _is_final_state(state: TaskState) -> bool:
        """Tell whether ``state`` is one of the states reported as final."""
        return state in (
            TaskState.COMPLETED,
            TaskState.CANCELED,
            TaskState.FAILED,
            TaskState.EXPIRED,
        )

    async def create_task(
        self,
        task_id: str,
        session_id: Optional[str] = None,
        message: Optional[Message] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Task:
        """Create a new task.

        If a task with ``task_id`` is already stored, that task is returned
        unchanged and the other arguments are ignored.

        Args:
            task_id: The ID of the task.
            session_id: The session ID. A random hex ID is generated when
                omitted.
            message: The initial message; it also becomes the first history
                entry.
            metadata: Additional metadata.

        Returns:
            The created (or already existing) task.
        """
        existing = self._tasks.get(task_id)
        if existing is not None:
            return existing

        now = datetime.now()
        status = TaskStatus(
            state=TaskState.SUBMITTED,
            message=message,
            timestamp=now,
            previous_state=None,  # Initial state has no previous state
        )
        task = Task(
            id=task_id,
            sessionId=session_id or uuid4().hex,
            status=status,
            artifacts=[],
            history=[message] if message else [],
            metadata=metadata or {},
        )

        self._tasks[task_id] = task
        self._task_subscribers[task_id] = set()
        self._task_timestamps[task_id] = now
        return task

    async def get_task(
        self, task_id: str, history_length: Optional[int] = None
    ) -> Task:
        """Get a task by ID.

        Args:
            task_id: The ID of the task.
            history_length: The number of most recent messages to include in
                the history. None returns the stored task itself; zero or a
                negative value yields an empty history.

        Returns:
            The stored task, or a deep copy with trimmed history when
            ``history_length`` is given and the task has history.

        Raises:
            KeyError: If the task is not found.
        """
        task = self._require_task(task_id)
        if history_length is None or not task.history:
            return task

        task_copy = task.model_copy(deep=True)
        # ``history[-0:]`` would be the whole list, so non-positive lengths are
        # handled explicitly. Slicing the copy keeps the returned messages
        # independent of the stored task.
        task_copy.history = (
            task_copy.history[-history_length:] if history_length > 0 else []
        )
        return task_copy

    async def update_task_status(
        self,
        task_id: str,
        state: TaskState,
        message: Optional[Message] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> TaskStatusUpdateEvent:
        """Update the status of a task.

        Args:
            task_id: The ID of the task.
            state: The new state of the task.
            message: An optional message to include with the status update;
                it is also appended to the task history.
            metadata: Additional metadata.

        Returns:
            The task status update event, which is also sent to subscribers.

        Raises:
            KeyError: If the task is not found.
            ValueError: If ``TaskState.is_valid_transition`` rejects the move
                from the current state to ``state``.
        """
        task = self._require_task(task_id)
        previous_state = task.status.state if task.status else None

        if previous_state and not TaskState.is_valid_transition(previous_state, state):
            raise ValueError(f"Invalid state transition from {previous_state} to {state}")

        now = datetime.now()
        status = TaskStatus(
            state=state,
            message=message,
            timestamp=now,
            previous_state=previous_state,
        )
        task.status = status

        if message and task.history is not None:
            task.history.append(message)

        # Refresh the timestamp that the TTL cleanup compares against.
        self._task_timestamps[task_id] = now

        event = TaskStatusUpdateEvent(
            id=task_id,
            status=status,
            final=self._is_final_state(state),
            metadata=metadata or {},
        )
        await self._notify_subscribers(task_id, event)
        return event

    async def add_task_artifact(
        self,
        task_id: str,
        artifact: Artifact,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> TaskArtifactUpdateEvent:
        """Add an artifact to a task.

        When ``artifact.append`` is set and the task already holds an artifact
        with the same name, the new parts are appended to that artifact and
        its ``lastChunk`` flag is overwritten. Otherwise the artifact is added
        as a new entry.

        Args:
            task_id: The ID of the task.
            artifact: The artifact to add.
            metadata: Additional metadata.

        Returns:
            The task artifact update event, which is also sent to subscribers.

        Raises:
            KeyError: If the task is not found.
        """
        task = self._require_task(task_id)
        if task.artifacts is None:
            task.artifacts = []

        target = None
        if artifact.append:
            target = next(
                (item for item in task.artifacts if item.name == artifact.name),
                None,
            )

        if target is None:
            task.artifacts.append(artifact)
        else:
            target.parts.extend(artifact.parts)
            target.lastChunk = artifact.lastChunk

        # The event carries the incoming artifact, not the merged one.
        event = TaskArtifactUpdateEvent(
            id=task_id,
            artifact=artifact,
            metadata=metadata or {},
        )
        await self._notify_subscribers(task_id, event)
        return event

    async def cancel_task(self, task_id: str) -> Task:
        """Cancel a task.

        A task that is already in a final state (completed, canceled, failed
        or expired) is returned unchanged.

        Args:
            task_id: The ID of the task.

        Returns:
            The canceled task.

        Raises:
            KeyError: If the task is not found.
        """
        task = self._require_task(task_id)
        if not self._is_final_state(task.status.state):
            await self.update_task_status(task_id, TaskState.CANCELED)
        return task

    async def set_push_notification(
        self, task_id: str, config: PushNotificationConfig
    ) -> PushNotificationConfig:
        """Set push notification for a task.

        Args:
            task_id: The ID of the task.
            config: The push notification configuration.

        Returns:
            The push notification configuration.

        Raises:
            KeyError: If the task is not found.
        """
        self._require_task(task_id)
        self._push_notifications[task_id] = config
        return config

    async def get_push_notification(
        self, task_id: str
    ) -> Optional[PushNotificationConfig]:
        """Get push notification for a task.

        Args:
            task_id: The ID of the task.

        Returns:
            The push notification configuration, or None if not set.

        Raises:
            KeyError: If the task is not found.
        """
        self._require_task(task_id)
        return self._push_notifications.get(task_id)

    async def subscribe_to_task(self, task_id: str) -> asyncio.Queue:
        """Subscribe to task updates.

        Args:
            task_id: The ID of the task.

        Returns:
            A new queue that will receive the task's update events.

        Raises:
            KeyError: If the task is not found.
        """
        self._require_task(task_id)
        queue: asyncio.Queue = asyncio.Queue()
        self._task_subscribers.setdefault(task_id, set()).add(queue)
        return queue

    async def unsubscribe_from_task(self, task_id: str, queue: asyncio.Queue) -> None:
        """Unsubscribe from task updates.

        Unknown task IDs and queues that are not subscribed are ignored.

        Args:
            task_id: The ID of the task.
            queue: The queue to unsubscribe.
        """
        subscribers = self._task_subscribers.get(task_id)
        if subscribers is not None:
            subscribers.discard(queue)

    async def _notify_subscribers(
        self,
        task_id: str,
        event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent],
    ) -> None:
        """Notify subscribers of a task update.

        Args:
            task_id: The ID of the task.
            event: The event to send to subscribers.
        """
        # Iterate over a snapshot so that a subscribe/unsubscribe happening
        # while a put is awaited cannot change the set mid-iteration.
        for queue in tuple(self._task_subscribers.get(task_id, ())):
            await queue.put(event)

    async def _periodic_cleanup(self) -> None:
        """Periodically clean up expired tasks.

        Sleeps for the cleanup interval, runs one cleanup pass and repeats
        until cancelled. Other exceptions are logged and the loop continues.
        """
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired_tasks()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self._logger.exception("Error during periodic cleanup: %s", e)

    async def _cleanup_expired_tasks(self) -> None:
        """Clean up expired tasks.

        A task is expired when more than ``_task_ttl`` seconds have passed
        since its recorded timestamp. Its stored state is removed, and if it
        has a subscriber entry, a final EXPIRED status event is sent before
        that entry is dropped.
        """
        now = datetime.now()
        expired_tasks = [
            task_id
            for task_id, timestamp in self._task_timestamps.items()
            if (now - timestamp).total_seconds() > self._task_ttl
        ]

        for task_id in expired_tasks:
            self._logger.info(f"Cleaning up expired task: {task_id}")
            # Keep the popped task so its last state can be reported below.
            task = self._tasks.pop(task_id, None)
            self._push_notifications.pop(task_id, None)
            self._task_timestamps.pop(task_id, None)

            if task_id in self._task_subscribers:
                previous_state = None
                if task is not None and task.status:
                    previous_state = task.status.state

                status = TaskStatus(
                    state=TaskState.EXPIRED,
                    timestamp=now,
                    previous_state=previous_state,
                )
                # Same ``id`` keyword as the other event construction sites.
                event = TaskStatusUpdateEvent(
                    id=task_id,
                    status=status,
                    final=True,
                )
                await self._notify_subscribers(task_id, event)

                self._task_subscribers.pop(task_id, None)
|