mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-01 17:28:14 +00:00
Compare commits
2 Commits
memory_pat
...
devin/1746
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9bb8854c25 | ||
|
|
cfabb9fa78 |
59
examples/a2a_protocol_example.py
Normal file
59
examples/a2a_protocol_example.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
Example of using the A2A protocol with CrewAI.

This example demonstrates how to:
1. Create an agent with A2A protocol support
2. Start an A2A server for the agent
3. Execute a task via the A2A protocol
"""

import asyncio
import os

import uvicorn
from threading import Thread

from crewai import Agent
from crewai.a2a import A2AServer, InMemoryTaskManager


# Agent exposed over the A2A protocol; a2a_url is where our local server listens.
agent = Agent(
    role="Data Analyst",
    goal="Analyze data and provide insights",
    backstory="I am a data analyst with expertise in finding patterns and insights in data.",
    a2a_enabled=True,
    a2a_url="http://localhost:8000",
)


def start_server():
    """Start the A2A server."""
    server = A2AServer(task_manager=InMemoryTaskManager())
    uvicorn.run(server.app, host="0.0.0.0", port=8000)


async def execute_task_via_a2a():
    """Execute a task via the A2A protocol."""
    # Give the background server thread a moment to start listening.
    await asyncio.sleep(2)

    result = await agent.execute_task_via_a2a(
        task_description="Analyze the following data and provide insights: [1, 2, 3, 4, 5]",
        context="This is a simple example of using the A2A protocol.",
    )

    print(f"Task result: {result}")


async def main():
    """Run the example."""
    # Daemon thread so the process exits once the client side is done.
    server_thread = Thread(target=start_server, daemon=True)
    server_thread.start()

    await execute_task_via_a2a()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -23,4 +23,9 @@ __all__ = [
|
||||
"LLM",
|
||||
"Flow",
|
||||
"Knowledge",
|
||||
"A2AAgentIntegration",
|
||||
"A2AClient",
|
||||
"A2AServer",
|
||||
]
|
||||
|
||||
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer
|
||||
|
||||
16
src/crewai/a2a/__init__.py
Normal file
16
src/crewai/a2a/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""A2A protocol implementation for CrewAI."""
|
||||
|
||||
from crewai.a2a.agent import A2AAgentIntegration
|
||||
from crewai.a2a.client import A2AClient
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai.a2a.server import A2AServer
|
||||
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
|
||||
|
||||
__all__ = [
|
||||
"A2AAgentIntegration",
|
||||
"A2AClient",
|
||||
"A2AServer",
|
||||
"TaskManager",
|
||||
"InMemoryTaskManager",
|
||||
"A2AConfig",
|
||||
]
|
||||
223
src/crewai/a2a/agent.py
Normal file
223
src/crewai/a2a/agent.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
A2A protocol agent integration for CrewAI.
|
||||
|
||||
This module implements the integration between CrewAI agents and the A2A protocol.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from crewai.a2a.client import A2AClient
|
||||
from crewai.a2a.task_manager import TaskManager
|
||||
from crewai.types.a2a import (
|
||||
Artifact,
|
||||
DataPart,
|
||||
FilePart,
|
||||
Message,
|
||||
Part,
|
||||
Task as A2ATask,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
|
||||
|
||||
class A2AAgentIntegration:
    """Integration between CrewAI agents and the A2A protocol."""

    def __init__(
        self,
        task_manager: Optional[TaskManager] = None,
        client: Optional[A2AClient] = None,
    ):
        """Initialize the A2A agent integration.

        Args:
            task_manager: The task manager to use for handling A2A tasks.
            client: The A2A client to use for sending tasks to other agents.
        """
        self.task_manager = task_manager
        self.client = client
        self.logger = logging.getLogger(__name__)

    async def execute_task_via_a2a(
        self,
        agent_url: str,
        task_description: str,
        context: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: int = 300,
    ) -> str:
        """Execute a task via the A2A protocol.

        Args:
            agent_url: The URL of the agent to execute the task.
            task_description: The description of the task.
            context: Additional context for the task.
            api_key: The API key to use for authentication.
            timeout: The timeout for the task execution in seconds.

        Returns:
            The result of the task execution.

        Raises:
            TimeoutError: If the task execution times out.
            Exception: If there is an error executing the task.
        """
        # Bug fix: a cached client used to be reused unconditionally, so a
        # later call with a different agent_url silently talked to the old
        # server. Recreate the client whenever the target URL changes.
        if not self.client or self.client.base_url != agent_url.rstrip("/"):
            self.client = A2AClient(base_url=agent_url, api_key=api_key)

        parts: List[Part] = [TextPart(text=task_description)]
        if context:
            # Context travels as a structured DataPart so the remote agent can
            # distinguish it from the task description itself.
            parts.append(
                DataPart(
                    data={"context": context},
                    metadata={"type": "context"},
                )
            )

        message = Message(role="user", parts=parts)

        task_id = str(uuid.uuid4())

        try:
            queue = await self.client.send_task_streaming(
                task_id=task_id,
                message=message,
            )

            result = await self._wait_for_task_completion(queue, timeout)
            return result
        except Exception as e:
            self.logger.exception(f"Error executing task via A2A: {e}")
            raise

    async def _wait_for_task_completion(
        self, queue: asyncio.Queue, timeout: int
    ) -> str:
        """Wait for a task to complete.

        Args:
            queue: The queue to receive task updates from.
            timeout: The timeout for the task execution in seconds.

        Returns:
            The result of the task execution.

        Raises:
            TimeoutError: If the task execution times out.
            Exception: If there is an error executing the task.
        """
        result = ""
        # Bug fix: timeout_task was previously only bound inside the try
        # block, so the finally clause could raise NameError if create_task
        # itself failed. Pre-bind it to None and guard the cancel.
        timeout_task: Optional[asyncio.Task] = None
        try:
            async def _timeout():
                # Delivers the timeout through the same queue the stream uses,
                # so the consumer loop below surfaces it as an exception.
                await asyncio.sleep(timeout)
                await queue.put(TimeoutError(f"Task execution timed out after {timeout} seconds"))

            timeout_task = asyncio.create_task(_timeout())

            while True:
                event = await queue.get()

                # The producer side pushes exceptions (including the timeout
                # sentinel above) straight onto the queue.
                if isinstance(event, Exception):
                    raise event

                if isinstance(event, TaskStatusUpdateEvent):
                    if event.status.state == TaskState.COMPLETED:
                        if event.status.message:
                            for part in event.status.message.parts:
                                if isinstance(part, TextPart):
                                    result += part.text
                        break
                    elif event.status.state in [TaskState.FAILED, TaskState.CANCELED]:
                        error_message = "Task failed"
                        if event.status.message:
                            for part in event.status.message.parts:
                                if isinstance(part, TextPart):
                                    error_message = part.text
                        raise Exception(error_message)
                elif isinstance(event, TaskArtifactUpdateEvent):
                    # Artifacts stream in incrementally; accumulate their text.
                    for part in event.artifact.parts:
                        if isinstance(part, TextPart):
                            result += part.text
        finally:
            if timeout_task is not None:
                timeout_task.cancel()

        return result

    async def handle_a2a_task(
        self,
        task: A2ATask,
        agent_execute_func: Any,
        context: Optional[str] = None,
    ) -> None:
        """Handle an A2A task.

        Args:
            task: The A2A task to handle.
            agent_execute_func: The function to execute the task.
            context: Additional context for the task.

        Raises:
            Exception: If there is an error handling the task.
        """
        if not self.task_manager:
            raise ValueError("Task manager is required to handle A2A tasks")

        try:
            # Mark the task as in-progress before doing any work.
            await self.task_manager.update_task_status(
                task_id=task.id,
                state=TaskState.WORKING,
            )

            task_description = ""
            task_context = context or ""

            # Pull the description (and any embedded context) out of the most
            # recent user message, if there is one.
            if task.history and task.history[-1].role == "user":
                message = task.history[-1]
                for part in message.parts:
                    if isinstance(part, TextPart):
                        task_description += part.text
                    elif isinstance(part, DataPart) and part.data.get("context"):
                        task_context += part.data["context"]

            try:
                result = await agent_execute_func(task_description, task_context)

                response_message = Message(
                    role="agent",
                    parts=[TextPart(text=result)],
                )

                await self.task_manager.update_task_status(
                    task_id=task.id,
                    state=TaskState.COMPLETED,
                    message=response_message,
                )

                # Also publish the result as a named artifact for consumers
                # that read artifacts rather than status messages.
                artifact = Artifact(
                    name="result",
                    parts=[TextPart(text=result)],
                )
                await self.task_manager.add_task_artifact(
                    task_id=task.id,
                    artifact=artifact,
                )
            except Exception as e:
                # Report the failure to the task manager before re-raising so
                # remote observers see the FAILED state.
                error_message = Message(
                    role="agent",
                    parts=[TextPart(text=str(e))],
                )
                await self.task_manager.update_task_status(
                    task_id=task.id,
                    state=TaskState.FAILED,
                    message=error_message,
                )
                raise
        except Exception as e:
            self.logger.exception(f"Error handling A2A task: {e}")
            raise
|
||||
470
src/crewai/a2a/client.py
Normal file
470
src/crewai/a2a/client.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
A2A protocol client for CrewAI.
|
||||
|
||||
This module implements the client for the A2A protocol in CrewAI.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
||||
|
||||
import aiohttp
|
||||
from pydantic import ValidationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
|
||||
from crewai.types.a2a import (
|
||||
A2AClientError,
|
||||
A2AClientHTTPError,
|
||||
A2AClientJSONError,
|
||||
Artifact,
|
||||
CancelTaskRequest,
|
||||
CancelTaskResponse,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationResponse,
|
||||
GetTaskRequest,
|
||||
GetTaskResponse,
|
||||
JSONRPCError,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
Message,
|
||||
MissingAPIKeyError,
|
||||
PushNotificationConfig,
|
||||
SendTaskRequest,
|
||||
SendTaskResponse,
|
||||
SendTaskStreamingRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
SetTaskPushNotificationResponse,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskIdParams,
|
||||
TaskPushNotificationConfig,
|
||||
TaskQueryParams,
|
||||
TaskSendParams,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
|
||||
|
||||
class A2AClient:
    """A2A protocol client implementation."""

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        timeout: Optional[int] = None,
        config: Optional["A2AConfig"] = None,
    ):
        """Initialize the A2A client.

        Args:
            base_url: The base URL of the A2A server.
            api_key: The API key to use for authentication.
            timeout: The timeout for HTTP requests in seconds.
            config: The A2A configuration. If provided, other parameters are ignored.
        """
        if config:
            self.config = config
        else:
            # Imported lazily to avoid a circular import at module load time.
            from crewai.a2a.config import A2AConfig
            self.config = A2AConfig()
            if api_key:
                self.config.api_key = api_key
            if timeout:
                self.config.client_timeout = timeout

        self.base_url = base_url.rstrip("/")
        self.api_key = self.config.api_key or os.environ.get("A2A_API_KEY")
        self.timeout = self.config.client_timeout
        self.logger = logging.getLogger(__name__)
        # Bug fix: asyncio.create_task only holds a weak reference to the
        # task; without a strong reference here, in-flight streaming tasks
        # could be garbage collected mid-stream.
        self._background_tasks: set = set()

    async def send_task(
        self,
        task_id: str,
        message: Message,
        session_id: Optional[str] = None,
        accepted_output_modes: Optional[List[str]] = None,
        push_notification: Optional[PushNotificationConfig] = None,
        history_length: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Task:
        """Send a task to the A2A server.

        Args:
            task_id: The ID of the task.
            message: The message to send.
            session_id: The session ID.
            accepted_output_modes: The accepted output modes.
            push_notification: The push notification configuration.
            history_length: The number of messages to include in the history.
            metadata: Additional metadata.

        Returns:
            The created task.

        Raises:
            MissingAPIKeyError: If no API key is provided.
            A2AClientHTTPError: If there is an HTTP error.
            A2AClientJSONError: If there is an error parsing the JSON response.
            A2AClientError: If there is any other error sending the task.
        """
        params = TaskSendParams(
            id=task_id,
            sessionId=session_id,
            message=message,
            acceptedOutputModes=accepted_output_modes,
            pushNotification=push_notification,
            historyLength=history_length,
            metadata=metadata,
        )

        request = SendTaskRequest(params=params)

        try:
            response = await self._send_jsonrpc_request(request)

            if response.error:
                raise A2AClientError(f"Error sending task: {response.error.message}")

            if not response.result:
                raise A2AClientError("No result returned from send task request")

            # The result may arrive as a raw dict or an already-parsed model.
            if isinstance(response.result, dict):
                return Task.model_validate(response.result)
            return cast(Task, response.result)
        except asyncio.TimeoutError as e:
            raise A2AClientError(f"Task request timed out: {e}")
        except aiohttp.ClientError as e:
            if isinstance(e, aiohttp.ClientResponseError):
                raise A2AClientHTTPError(e.status, str(e))
            else:
                raise A2AClientError(f"Client error: {e}")

    async def send_task_streaming(
        self,
        task_id: str,
        message: Message,
        session_id: Optional[str] = None,
        accepted_output_modes: Optional[List[str]] = None,
        push_notification: Optional[PushNotificationConfig] = None,
        history_length: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> asyncio.Queue:
        """Send a task to the A2A server and subscribe to updates.

        Args:
            task_id: The ID of the task.
            message: The message to send.
            session_id: The session ID.
            accepted_output_modes: The accepted output modes.
            push_notification: The push notification configuration.
            history_length: The number of messages to include in the history.
            metadata: Additional metadata.

        Returns:
            A queue that will receive task updates.

        Raises:
            A2AClientError: If there is an error sending the task.
        """
        params = TaskSendParams(
            id=task_id,
            sessionId=session_id,
            message=message,
            acceptedOutputModes=accepted_output_modes,
            pushNotification=push_notification,
            historyLength=history_length,
            metadata=metadata,
        )

        queue: asyncio.Queue = asyncio.Queue()

        stream_task = asyncio.create_task(
            self._handle_streaming_response(
                f"{self.base_url}/v1/tasks/sendSubscribe", params, queue
            )
        )
        # Bug fix: keep a strong reference to the background task so the
        # event loop's weak reference is not the only one (see asyncio docs).
        self._background_tasks.add(stream_task)
        stream_task.add_done_callback(self._background_tasks.discard)

        return queue

    async def get_task(
        self, task_id: str, history_length: Optional[int] = None
    ) -> Task:
        """Get a task from the A2A server.

        Args:
            task_id: The ID of the task.
            history_length: The number of messages to include in the history.

        Returns:
            The task.

        Raises:
            A2AClientError: If there is an error getting the task.
        """
        params = TaskQueryParams(id=task_id, historyLength=history_length)
        request = GetTaskRequest(params=params)
        response = await self._send_jsonrpc_request(request)

        if response.error:
            raise A2AClientError(f"Error getting task: {response.error.message}")

        if not response.result:
            raise A2AClientError("No result returned from get task request")

        return cast(Task, response.result)

    async def cancel_task(self, task_id: str) -> Task:
        """Cancel a task on the A2A server.

        Args:
            task_id: The ID of the task.

        Returns:
            The canceled task.

        Raises:
            A2AClientError: If there is an error canceling the task.
        """
        params = TaskIdParams(id=task_id)
        request = CancelTaskRequest(params=params)
        response = await self._send_jsonrpc_request(request)

        if response.error:
            raise A2AClientError(f"Error canceling task: {response.error.message}")

        if not response.result:
            raise A2AClientError("No result returned from cancel task request")

        return cast(Task, response.result)

    async def set_push_notification(
        self, task_id: str, config: PushNotificationConfig
    ) -> PushNotificationConfig:
        """Set push notification for a task.

        Args:
            task_id: The ID of the task.
            config: The push notification configuration.

        Returns:
            The push notification configuration.

        Raises:
            A2AClientError: If there is an error setting the push notification.
        """
        params = TaskPushNotificationConfig(id=task_id, pushNotificationConfig=config)
        request = SetTaskPushNotificationRequest(params=params)
        response = await self._send_jsonrpc_request(request)

        if response.error:
            raise A2AClientError(
                f"Error setting push notification: {response.error.message}"
            )

        if not response.result:
            raise A2AClientError(
                "No result returned from set push notification request"
            )

        return cast(TaskPushNotificationConfig, response.result).pushNotificationConfig

    async def get_push_notification(
        self, task_id: str
    ) -> Optional[PushNotificationConfig]:
        """Get push notification for a task.

        Args:
            task_id: The ID of the task.

        Returns:
            The push notification configuration, or None if not set.

        Raises:
            A2AClientError: If there is an error getting the push notification.
        """
        params = TaskIdParams(id=task_id)
        request = GetTaskPushNotificationRequest(params=params)
        response = await self._send_jsonrpc_request(request)

        if response.error:
            raise A2AClientError(
                f"Error getting push notification: {response.error.message}"
            )

        if not response.result:
            return None

        return cast(TaskPushNotificationConfig, response.result).pushNotificationConfig

    async def _send_jsonrpc_request(
        self, request: JSONRPCRequest
    ) -> JSONRPCResponse:
        """Send a JSON-RPC request to the A2A server.

        Args:
            request: The JSON-RPC request.

        Returns:
            The JSON-RPC response.

        Raises:
            A2AClientError: If there is an error sending the request.
        """
        if not self.api_key:
            raise MissingAPIKeyError(
                "API key is required. Set it in the constructor or as the A2A_API_KEY environment variable."
            )

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.base_url}/v1/jsonrpc",
                    headers=headers,
                    json=request.model_dump(),
                    timeout=self.timeout,
                ) as response:
                    if response.status != 200:
                        raise A2AClientHTTPError(
                            response.status, await response.text()
                        )

                    try:
                        data = await response.json()
                    except json.JSONDecodeError as e:
                        raise A2AClientJSONError(str(e))

                    try:
                        return JSONRPCResponse.model_validate(data)
                    except ValidationError as e:
                        raise A2AClientError(f"Invalid response: {e}")
        except aiohttp.ClientConnectorError as e:
            raise A2AClientHTTPError(status=0, message=f"Connection error: {e}")
        except aiohttp.ClientOSError as e:
            raise A2AClientHTTPError(status=0, message=f"OS error: {e}")
        except aiohttp.ServerDisconnectedError as e:
            raise A2AClientHTTPError(status=0, message=f"Server disconnected: {e}")
        except aiohttp.ClientResponseError as e:
            raise A2AClientHTTPError(e.status, str(e))
        except aiohttp.ClientError as e:
            raise A2AClientError(f"HTTP error: {e}")

    async def _dispatch_sse_event(self, raw: str, queue: asyncio.Queue) -> bool:
        """Parse one raw SSE event block and push the decoded event onto the queue.

        Args:
            raw: The raw text of one SSE event (possibly several lines).
            queue: The queue to put the decoded event (or a decode error) into.

        Returns:
            True when the event is a final status update (the stream is done).
        """
        event_data = self._parse_sse_event(raw)
        if not event_data:
            return False

        event_type = event_data.get("event")
        data = event_data.get("data")

        if event_type == "status":
            try:
                event = TaskStatusUpdateEvent.model_validate_json(data)
            except ValidationError as e:
                await queue.put(A2AClientError(f"Invalid status event: {e}"))
                return False
            await queue.put(event)
            return bool(event.final)
        elif event_type == "artifact":
            try:
                event = TaskArtifactUpdateEvent.model_validate_json(data)
            except ValidationError as e:
                await queue.put(A2AClientError(f"Invalid artifact event: {e}"))
                return False
            await queue.put(event)
        return False

    async def _handle_streaming_response(
        self,
        url: str,
        params: TaskSendParams,
        queue: asyncio.Queue,
    ) -> None:
        """Handle a streaming response from the A2A server.

        Args:
            url: The URL to send the request to.
            params: The task send parameters.
            queue: The queue to put events into.
        """
        if not self.api_key:
            await queue.put(
                Exception(
                    "API key is required. Set it in the constructor or as the A2A_API_KEY environment variable."
                )
            )
            return

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "text/event-stream",
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url,
                    headers=headers,
                    json=params.model_dump(),
                    timeout=self.timeout,
                ) as response:
                    if response.status != 200:
                        await queue.put(
                            A2AClientHTTPError(response.status, await response.text())
                        )
                        return

                    # SSE events are separated by a blank line ("\n\n").
                    buffer = ""
                    final_seen = False
                    async for line in response.content:
                        buffer += line.decode("utf-8")

                        if buffer.endswith("\n\n"):
                            final_seen = await self._dispatch_sse_event(buffer, queue)
                            buffer = ""
                            if final_seen:
                                break

                    # Bug fix: a trailing event not terminated by a blank line
                    # (e.g. the server closed the stream right after it) used
                    # to be silently dropped.
                    if not final_seen and buffer.strip():
                        await self._dispatch_sse_event(buffer, queue)
        except aiohttp.ClientConnectorError as e:
            await queue.put(A2AClientHTTPError(status=0, message=f"Connection error: {e}"))
        except aiohttp.ClientOSError as e:
            await queue.put(A2AClientHTTPError(status=0, message=f"OS error: {e}"))
        except aiohttp.ServerDisconnectedError as e:
            await queue.put(A2AClientHTTPError(status=0, message=f"Server disconnected: {e}"))
        except aiohttp.ClientResponseError as e:
            await queue.put(A2AClientHTTPError(e.status, str(e)))
        except aiohttp.ClientError as e:
            await queue.put(A2AClientError(f"HTTP error: {e}"))
        except asyncio.CancelledError:
            pass
        except Exception as e:
            await queue.put(A2AClientError(f"Error handling streaming response: {e}"))

    def _parse_sse_event(self, data: str) -> Dict[str, str]:
        """Parse an SSE event.

        Args:
            data: The SSE event data.

        Returns:
            A dictionary with the event type and data.
        """
        result = {}
        for line in data.split("\n"):
            line = line.strip()
            if not line:
                continue

            if line.startswith("event:"):
                result["event"] = line[6:].strip()
            elif line.startswith("data:"):
                result["data"] = line[5:].strip()

        return result
|
||||
89
src/crewai/a2a/config.py
Normal file
89
src/crewai/a2a/config.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Configuration management for A2A protocol in CrewAI.
|
||||
|
||||
This module provides configuration management for the A2A protocol implementation
|
||||
in CrewAI, including default values and environment variable support.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class A2AConfig(BaseModel):
    """Configuration for A2A protocol."""

    # --- server settings ---
    server_host: str = Field(
        default="0.0.0.0",
        description="Host to bind the A2A server to.",
    )
    server_port: int = Field(
        default=8000,
        description="Port to bind the A2A server to.",
    )
    enable_cors: bool = Field(
        default=True,
        description="Whether to enable CORS for the A2A server.",
    )
    cors_origins: Optional[list[str]] = Field(
        default=None,
        description="CORS origins to allow. If None, all origins are allowed.",
    )

    # --- client settings ---
    client_timeout: int = Field(
        default=60,
        description="Timeout for A2A client requests in seconds.",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="API key for A2A authentication.",
    )

    # --- task lifecycle settings ---
    task_ttl: int = Field(
        default=3600,
        description="Time-to-live for tasks in seconds.",
    )
    cleanup_interval: int = Field(
        default=300,
        description="Interval for cleaning up expired tasks in seconds.",
    )
    max_history_length: int = Field(
        default=100,
        description="Maximum number of messages to include in task history.",
    )

    @classmethod
    def from_env(cls) -> "A2AConfig":
        """Create a configuration from environment variables.

        Environment variables are prefixed with A2A_ and are uppercase.
        For example, A2A_SERVER_PORT=8080 will set server_port to 8080.

        Returns:
            A2AConfig: The configuration.
        """
        def _bool(raw: str) -> bool:
            # Only the literal (case-insensitive) "true" enables a flag.
            return raw.lower() == "true"

        def _csv(raw: str) -> list[str]:
            return raw.split(",")

        # Declarative field -> (env var, converter) table; any variable that
        # is unset simply falls back to the field's default above.
        env_spec = {
            "server_host": ("A2A_SERVER_HOST", str),
            "server_port": ("A2A_SERVER_PORT", int),
            "enable_cors": ("A2A_ENABLE_CORS", _bool),
            "cors_origins": ("A2A_CORS_ORIGINS", _csv),
            "client_timeout": ("A2A_CLIENT_TIMEOUT", int),
            "api_key": ("A2A_API_KEY", str),
            "task_ttl": ("A2A_TASK_TTL", int),
            "cleanup_interval": ("A2A_CLEANUP_INTERVAL", int),
            "max_history_length": ("A2A_MAX_HISTORY_LENGTH", int),
        }

        overrides: Dict[str, Union[str, int, bool, list[str]]] = {
            field: convert(os.environ[var])
            for field, (var, convert) in env_spec.items()
            if var in os.environ
        }

        return cls(**overrides)
|
||||
515
src/crewai/a2a/server.py
Normal file
515
src/crewai/a2a/server.py
Normal file
@@ -0,0 +1,515 @@
|
||||
"""
|
||||
A2A protocol server for CrewAI.
|
||||
|
||||
This module implements the server for the A2A protocol in CrewAI.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import ValidationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
|
||||
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
|
||||
from crewai.types.a2a import (
|
||||
A2ARequest,
|
||||
CancelTaskRequest,
|
||||
CancelTaskResponse,
|
||||
ContentTypeNotSupportedError,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationResponse,
|
||||
GetTaskRequest,
|
||||
GetTaskResponse,
|
||||
InternalError,
|
||||
InvalidParamsError,
|
||||
InvalidRequestError,
|
||||
JSONParseError,
|
||||
JSONRPCError,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
MethodNotFoundError,
|
||||
SendTaskRequest,
|
||||
SendTaskResponse,
|
||||
SendTaskStreamingRequest,
|
||||
SendTaskStreamingResponse,
|
||||
SetTaskPushNotificationRequest,
|
||||
SetTaskPushNotificationResponse,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskIdParams,
|
||||
TaskNotCancelableError,
|
||||
TaskNotFoundError,
|
||||
TaskPushNotificationConfig,
|
||||
TaskQueryParams,
|
||||
TaskSendParams,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
UnsupportedOperationError,
|
||||
)
|
||||
|
||||
|
||||
class A2AServer:
|
||||
"""A2A protocol server implementation."""
|
||||
|
||||
def __init__(
    self,
    task_manager: Optional[TaskManager] = None,
    enable_cors: Optional[bool] = None,
    cors_origins: Optional[List[str]] = None,
    config: Optional["A2AConfig"] = None,
):
    """Initialize the A2A server.

    Args:
        task_manager: The task manager to use. If None, an InMemoryTaskManager will be created.
        enable_cors: Whether to enable CORS. If None, uses config value.
        cors_origins: The CORS origins to allow. If None, uses config value.
        config: The A2A configuration. If provided, other parameters are ignored.
    """
    from crewai.a2a.config import A2AConfig

    self.config = config or A2AConfig.from_env()

    # Explicit constructor arguments win over configuration defaults.
    enable_cors = enable_cors if enable_cors is not None else self.config.enable_cors
    cors_origins = cors_origins or self.config.cors_origins

    self.app = FastAPI(
        title="A2A Protocol Server",
        description="""
    A2A (Agent-to-Agent) protocol server for CrewAI.

    This server implements Google's A2A protocol specification, enabling interoperability
    between different agent systems. It provides endpoints for task creation, retrieval,
    cancellation, and streaming updates.
    """,
        version="1.0.0",
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_tags=[
            {
                "name": "tasks",
                "description": "Operations for managing A2A tasks",
            },
            {
                "name": "jsonrpc",
                "description": "JSON-RPC interface for the A2A protocol",
            },
        ],
    )
    self.task_manager = task_manager or InMemoryTaskManager()
    self.logger = logging.getLogger(__name__)

    if enable_cors:
        # NOTE(review): allow_credentials=True combined with a wildcard
        # origin list is a permissive CORS posture — confirm this is
        # intended for production deployments.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=cors_origins or ["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # Route registration: each nested handler simply delegates to the
    # corresponding bound method so the routes can be declared with the
    # FastAPI decorator metadata in one place.

    @self.app.post(
        "/v1/jsonrpc",
        summary="Handle JSON-RPC requests",
        description="""
    Process JSON-RPC requests for the A2A protocol.

    This endpoint handles all JSON-RPC requests for the A2A protocol, including:
    - SendTask: Create a new task
    - GetTask: Retrieve a task by ID
    - CancelTask: Cancel a running task
    - SetTaskPushNotification: Configure push notifications for a task
    - GetTaskPushNotification: Retrieve push notification configuration for a task
    """,
        response_model=JSONRPCResponse,
        responses={
            200: {"description": "Successful response with result or error"},
            400: {"description": "Invalid request format or parameters"},
            500: {"description": "Internal server error during processing"},
        },
        tags=["jsonrpc"],
    )
    async def handle_jsonrpc(request: Request):
        return await self.handle_jsonrpc(request)

    @self.app.post(
        "/v1/tasks/send",
        summary="Send a task to an agent",
        description="""
    Create a new task and send it to an agent for execution.

    This endpoint allows clients to send tasks to agents for processing.
    The task is created with the provided parameters and immediately
    transitions to the WORKING state. The response includes the created
    task with its current status.
    """,
        response_model=Task,
        responses={
            200: {"description": "Task created successfully and processing started"},
            400: {"description": "Invalid request format or parameters"},
            500: {"description": "Internal server error during task creation or processing"},
        },
        tags=["tasks"],
    )
    async def handle_send_task(request: Request):
        return await self.handle_send_task(request)

    @self.app.post(
        "/v1/tasks/sendSubscribe",
        summary="Send a task and subscribe to updates",
        description="""
    Create a new task and subscribe to status updates via Server-Sent Events (SSE).

    This endpoint allows clients to send tasks to agents and receive real-time
    updates as the task progresses. The response is a streaming SSE connection
    that provides status updates and artifact notifications until the task
    reaches a terminal state (COMPLETED, FAILED, CANCELED, or EXPIRED).
    """,
        responses={
            200: {
                "description": "Streaming response with task updates",
                "content": {
                    "text/event-stream": {
                        "schema": {"type": "string"},
                        "example": 'event: status\ndata: {"task_id": "123", "status": {"state": "WORKING"}}\n\n',
                    }
                },
            },
            400: {"description": "Invalid request format or parameters"},
            500: {"description": "Internal server error during task creation or processing"},
        },
        tags=["tasks"],
    )
    async def handle_send_task_subscribe(request: Request):
        return await self.handle_send_task_subscribe(request)

    @self.app.post(
        "/v1/tasks/{task_id}/cancel",
        summary="Cancel a task",
        description="""
    Cancel a running task by ID.

    This endpoint allows clients to cancel a task that is currently in progress.
    The task must be in a non-terminal state (PENDING, WORKING) to be canceled.
    Once canceled, the task transitions to the CANCELED state and cannot be
    resumed. The response includes the updated task with its current status.
    """,
        response_model=Task,
        responses={
            200: {"description": "Task canceled successfully and status updated to CANCELED"},
            404: {"description": "Task not found or already expired"},
            409: {"description": "Task cannot be canceled (already in terminal state)"},
            500: {"description": "Internal server error during task cancellation"},
        },
        tags=["tasks"],
    )
    async def handle_cancel_task(task_id: str, request: Request):
        return await self.handle_cancel_task(task_id, request)

    @self.app.get(
        "/v1/tasks/{task_id}",
        summary="Get task details",
        description="""
    Retrieve details of a task by ID.

    This endpoint allows clients to retrieve the current state and details of a task.
    The response includes the task's status, history, and any associated metadata.
    Clients can specify the history_length parameter to limit the number of messages
    included in the response.
    """,
        response_model=Task,
        responses={
            200: {"description": "Task details retrieved successfully with current status"},
            404: {"description": "Task not found or expired"},
            500: {"description": "Internal server error during task retrieval"},
        },
        tags=["tasks"],
    )
    async def handle_get_task(task_id: str, request: Request):
        return await self.handle_get_task(task_id, request)
|
||||
|
||||
async def handle_jsonrpc(self, request: Request) -> JSONResponse:
    """Handle JSON-RPC requests.

    Decodes the request body, dispatches single or batch JSON-RPC
    payloads to :meth:`_process_jsonrpc_request`, and maps failures to
    JSON-RPC error responses.

    Args:
        request: The FastAPI request.

    Returns:
        A JSON response.
    """
    try:
        body = await request.json()
    except json.JSONDecodeError:
        parse_failure = JSONRPCResponse(id=None, error=JSONParseError())
        return JSONResponse(content=parse_failure.model_dump(), status_code=400)

    try:
        if isinstance(body, list):
            # Batch request: process entries in order, one response each.
            payload = [
                (await self._process_jsonrpc_request(entry)).model_dump()
                for entry in body
            ]
            return JSONResponse(content=payload)

        single = await self._process_jsonrpc_request(body)
        return JSONResponse(content=single.model_dump())
    except Exception:
        self.logger.exception("Error processing JSON-RPC request")
        failure = JSONRPCResponse(
            id=body.get("id") if isinstance(body, dict) else None,
            error=InternalError(message="Internal server error"),
        )
        return JSONResponse(content=failure.model_dump(), status_code=500)
|
||||
|
||||
async def _process_jsonrpc_request(
    self, request_data: Dict[str, Any]
) -> JSONRPCResponse:
    """Process a single JSON-RPC request payload.

    Validates the envelope, parses it into a typed A2A request, and
    dispatches to the task manager.

    Args:
        request_data: The JSON-RPC request data.

    Returns:
        A JSON-RPC response.
    """
    # Reject anything that is not a JSON-RPC 2.0 request object.
    if not isinstance(request_data, dict) or request_data.get("jsonrpc") != "2.0":
        bad_id = request_data.get("id") if isinstance(request_data, dict) else None
        return JSONRPCResponse(id=bad_id, error=InvalidRequestError())

    request_id = request_data.get("id")
    method = request_data.get("method")

    if not method:
        return JSONRPCResponse(
            id=request_id,
            error=InvalidRequestError(message="Method is required"),
        )

    try:
        request = A2ARequest.validate_python(request_data)
    except ValidationError as exc:
        return JSONRPCResponse(id=request_id, error=InvalidParamsError(data=str(exc)))

    try:
        if isinstance(request, SendTaskRequest):
            created = await self._handle_send_task(request.params)
            return SendTaskResponse(id=request_id, result=created)

        if isinstance(request, GetTaskRequest):
            found = await self.task_manager.get_task(
                request.params.id, request.params.historyLength
            )
            return GetTaskResponse(id=request_id, result=found)

        if isinstance(request, CancelTaskRequest):
            canceled = await self.task_manager.cancel_task(request.params.id)
            return CancelTaskResponse(id=request_id, result=canceled)

        if isinstance(request, SetTaskPushNotificationRequest):
            stored = await self.task_manager.set_push_notification(
                request.params.id, request.params.pushNotificationConfig
            )
            return SetTaskPushNotificationResponse(
                id=request_id,
                result=TaskPushNotificationConfig(
                    id=request.params.id, pushNotificationConfig=stored
                ),
            )

        if isinstance(request, GetTaskPushNotificationRequest):
            stored = await self.task_manager.get_push_notification(request.params.id)
            if not stored:
                return GetTaskPushNotificationResponse(id=request_id, result=None)
            return GetTaskPushNotificationResponse(
                id=request_id,
                result=TaskPushNotificationConfig(
                    id=request.params.id, pushNotificationConfig=stored
                ),
            )

        if isinstance(request, SendTaskStreamingRequest):
            # Streaming is served over SSE, not over plain JSON-RPC.
            return JSONRPCResponse(
                id=request_id,
                error=UnsupportedOperationError(
                    message="Streaming requests should be sent to the streaming endpoint"
                ),
            )

        return JSONRPCResponse(id=request_id, error=MethodNotFoundError())
    except KeyError:
        # The task manager raises KeyError for unknown task ids.
        return JSONRPCResponse(id=request_id, error=TaskNotFoundError())
    except Exception:
        self.logger.exception(f"Error handling {method} request")
        return JSONRPCResponse(
            id=request_id,
            error=InternalError(message="Internal server error"),
        )
|
||||
|
||||
async def handle_send_task(self, request: Request) -> JSONResponse:
    """Handle REST send-task requests.

    Args:
        request: The FastAPI request.

    Returns:
        A JSON response containing the created task, or an error payload
        with a 400/500 status code.
    """
    try:
        payload = await request.json()
        send_params = TaskSendParams.model_validate(payload)
        created = await self._handle_send_task(send_params)
        return JSONResponse(content=created.model_dump())
    except ValidationError:
        return JSONResponse(
            content={"error": "Invalid request format or parameters"},
            status_code=400,
        )
    except Exception:
        self.logger.exception("Error handling send task request")
        return JSONResponse(
            content={"error": "Internal server error"},
            status_code=500,
        )
|
||||
|
||||
async def _handle_send_task(self, params: TaskSendParams) -> Task:
    """Create a task from send parameters and mark it as in progress.

    Args:
        params: The task send parameters.

    Returns:
        The created task.
    """
    created = await self.task_manager.create_task(
        task_id=params.id,
        session_id=params.sessionId,
        message=params.message,
        metadata=params.metadata,
    )

    # Immediately transition to WORKING so callers see that processing
    # has started.
    await self.task_manager.update_task_status(
        task_id=params.id,
        state=TaskState.WORKING,
    )

    return created
|
||||
|
||||
async def handle_send_task_subscribe(
    self, request: Request
) -> "StreamingResponse | JSONResponse":
    """Handle send task subscribe requests.

    Creates the task, subscribes to its update queue, and streams SSE
    events back to the client until the task reaches a terminal state.

    Args:
        request: The FastAPI request.

    Returns:
        A streaming SSE response on success, or a JSON error response
        (400/500) on failure. (The previous ``-> StreamingResponse``
        annotation was wrong: the error paths return ``JSONResponse``.)
    """
    try:
        body = await request.json()
        params = TaskSendParams.model_validate(body)

        await self._handle_send_task(params)

        # Subscribe before returning the stream so updates emitted right
        # after task creation are not lost.
        queue = await self.task_manager.subscribe_to_task(params.id)

        return StreamingResponse(
            self._stream_task_updates(params.id, queue),
            media_type="text/event-stream",
        )
    except ValidationError:
        return JSONResponse(
            content={"error": "Invalid request format or parameters"},
            status_code=400,
        )
    except Exception:
        self.logger.exception("Error handling send task subscribe request")
        return JSONResponse(
            content={"error": "Internal server error"},
            status_code=500,
        )
|
||||
|
||||
async def _stream_task_updates(
    self, task_id: str, queue: asyncio.Queue
) -> "AsyncGenerator[str, None]":
    """Yield SSE-formatted task update events from the subscriber queue.

    This is an async generator (the previous ``-> None`` annotation was
    incorrect); it runs until a final status event arrives and always
    unsubscribes the queue on exit, including client disconnects.

    Args:
        task_id: The ID of the task.
        queue: The queue to receive updates from.

    Yields:
        SSE formatted events ("event: <type>\\ndata: <json>\\n\\n").
    """
    try:
        while True:
            event = await queue.get()

            # Map the event object to its SSE event name.
            if isinstance(event, TaskStatusUpdateEvent):
                event_type = "status"
            elif isinstance(event, TaskArtifactUpdateEvent):
                event_type = "artifact"
            else:
                event_type = "unknown"

            data = json.dumps(event.model_dump())
            yield f"event: {event_type}\ndata: {data}\n\n"

            # Stop once the task reaches a terminal state.
            if isinstance(event, TaskStatusUpdateEvent) and event.final:
                break
    finally:
        await self.task_manager.unsubscribe_from_task(task_id, queue)
|
||||
|
||||
async def handle_get_task(self, task_id: str, request: Request) -> JSONResponse:
    """Handle get task requests.

    Args:
        task_id: The ID of the task.
        request: The FastAPI request (``historyLength`` query parameter
            optionally limits the returned message history).

    Returns:
        A JSON response with the task payload.

    Raises:
        HTTPException: 400 for a non-integer historyLength, 404 if the
            task is not found, 500 for unexpected failures.
    """
    raw_length = request.query_params.get("historyLength")
    try:
        history_length = int(raw_length) if raw_length else None
    except ValueError:
        # Previously a malformed historyLength fell into the generic
        # handler and surfaced as a 500; it is a client error.
        raise HTTPException(
            status_code=400, detail="historyLength must be an integer"
        )

    try:
        task = await self.task_manager.get_task(task_id, history_length)
        return JSONResponse(content=task.model_dump())
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
    except Exception:
        self.logger.exception(f"Error handling get task request for {task_id}")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
async def handle_cancel_task(self, task_id: str, request: Request) -> JSONResponse:
    """Handle cancel task requests.

    Args:
        task_id: The ID of the task.
        request: The FastAPI request.

    Returns:
        A JSON response with the (now canceled) task payload.

    Raises:
        HTTPException: 404 if the task is not found, 500 on unexpected
            failures.
    """
    try:
        canceled = await self.task_manager.cancel_task(task_id)
        return JSONResponse(content=canceled.model_dump())
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
    except Exception:
        self.logger.exception(f"Error handling cancel task request for {task_id}")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
522
src/crewai/a2a/task_manager.py
Normal file
522
src/crewai/a2a/task_manager.py
Normal file
@@ -0,0 +1,522 @@
|
||||
"""
|
||||
A2A protocol task manager for CrewAI.
|
||||
|
||||
This module implements the task manager for the A2A protocol in CrewAI.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
|
||||
from uuid import uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
|
||||
from crewai.types.a2a import (
|
||||
Artifact,
|
||||
Message,
|
||||
PushNotificationConfig,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
|
||||
|
||||
class TaskManager(ABC):
    """Abstract base class for A2A task managers.

    Concrete implementations own the lifecycle of A2A tasks: creation,
    retrieval, status transitions, artifact accumulation, cancellation,
    and push-notification configuration.
    """

    @abstractmethod
    async def create_task(
        self,
        task_id: str,
        session_id: Optional[str] = None,
        message: Optional[Message] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Task:
        """Create a new task.

        Args:
            task_id: The ID of the task.
            session_id: The session ID.
            message: The initial message.
            metadata: Additional metadata.

        Returns:
            The created task.
        """
        pass

    @abstractmethod
    async def get_task(
        self, task_id: str, history_length: Optional[int] = None
    ) -> Task:
        """Get a task by ID.

        Args:
            task_id: The ID of the task.
            history_length: The number of messages to include in the history.

        Returns:
            The task.

        Raises:
            KeyError: If the task is not found.
        """
        pass

    @abstractmethod
    async def update_task_status(
        self,
        task_id: str,
        state: TaskState,
        message: Optional[Message] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> TaskStatusUpdateEvent:
        """Update the status of a task.

        Args:
            task_id: The ID of the task.
            state: The new state of the task.
            message: An optional message to include with the status update.
            metadata: Additional metadata.

        Returns:
            The task status update event.

        Raises:
            KeyError: If the task is not found.
        """
        pass

    @abstractmethod
    async def add_task_artifact(
        self,
        task_id: str,
        artifact: Artifact,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> TaskArtifactUpdateEvent:
        """Add an artifact to a task.

        Args:
            task_id: The ID of the task.
            artifact: The artifact to add.
            metadata: Additional metadata.

        Returns:
            The task artifact update event.

        Raises:
            KeyError: If the task is not found.
        """
        pass

    @abstractmethod
    async def cancel_task(self, task_id: str) -> Task:
        """Cancel a task.

        Args:
            task_id: The ID of the task.

        Returns:
            The canceled task.

        Raises:
            KeyError: If the task is not found.
        """
        pass

    @abstractmethod
    async def set_push_notification(
        self, task_id: str, config: PushNotificationConfig
    ) -> PushNotificationConfig:
        """Set push notification for a task.

        Args:
            task_id: The ID of the task.
            config: The push notification configuration.

        Returns:
            The push notification configuration.

        Raises:
            KeyError: If the task is not found.
        """
        pass

    @abstractmethod
    async def get_push_notification(
        self, task_id: str
    ) -> Optional[PushNotificationConfig]:
        """Get push notification for a task.

        Args:
            task_id: The ID of the task.

        Returns:
            The push notification configuration, or None if not set.

        Raises:
            KeyError: If the task is not found.
        """
        pass
|
||||
|
||||
|
||||
class InMemoryTaskManager(TaskManager):
    """In-memory implementation of the A2A task manager.

    Tasks, push-notification configs, subscriber queues, and last-touched
    timestamps are kept in plain dicts; expired tasks are garbage-collected
    by a background asyncio task when an event loop is available at
    construction time.
    """

    def __init__(
        self,
        task_ttl: Optional[int] = None,
        cleanup_interval: Optional[int] = None,
        config: Optional["A2AConfig"] = None,
    ):
        """Initialize the in-memory task manager.

        Args:
            task_ttl: Time to live for tasks in seconds. Default is 1 hour.
            cleanup_interval: Interval for cleaning up expired tasks in seconds. Default is 5 minutes.
            config: The A2A configuration. If provided, other parameters are ignored.
        """
        from crewai.a2a.config import A2AConfig

        self.config = config or A2AConfig.from_env()

        # Explicit arguments take precedence over configuration defaults.
        self._task_ttl = task_ttl if task_ttl is not None else self.config.task_ttl
        self._cleanup_interval = (
            cleanup_interval
            if cleanup_interval is not None
            else self.config.cleanup_interval
        )

        self._tasks: Dict[str, Task] = {}
        self._push_notifications: Dict[str, PushNotificationConfig] = {}
        self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {}
        self._task_timestamps: Dict[str, datetime] = {}
        self._logger = logging.getLogger(__name__)
        self._cleanup_task = None

        # Start the periodic cleanup only when an event loop is already
        # running; asyncio.get_running_loop() raises RuntimeError otherwise.
        try:
            asyncio.get_running_loop()
            self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
        except RuntimeError:
            self._logger.info("No running event loop, periodic cleanup disabled")

    async def create_task(
        self,
        task_id: str,
        session_id: Optional[str] = None,
        message: Optional[Message] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Task:
        """Create a new task.

        Creation is idempotent: if the task already exists, the existing
        task is returned unchanged.

        Args:
            task_id: The ID of the task.
            session_id: The session ID. A random one is generated if omitted.
            message: The initial message.
            metadata: Additional metadata.

        Returns:
            The created (or pre-existing) task.
        """
        if task_id in self._tasks:
            return self._tasks[task_id]

        session_id = session_id or uuid4().hex
        status = TaskStatus(
            state=TaskState.SUBMITTED,
            message=message,
            timestamp=datetime.now(),
            previous_state=None,  # Initial state has no previous state
        )

        task = Task(
            id=task_id,
            sessionId=session_id,
            status=status,
            artifacts=[],
            history=[message] if message else [],
            metadata=metadata or {},
        )

        self._tasks[task_id] = task
        self._task_subscribers[task_id] = set()
        self._task_timestamps[task_id] = datetime.now()
        return task

    async def get_task(
        self, task_id: str, history_length: Optional[int] = None
    ) -> Task:
        """Get a task by ID.

        Args:
            task_id: The ID of the task.
            history_length: If given, only the most recent messages are
                included in the returned task's history.

        Returns:
            The task (a deep copy when the history is truncated, so the
            stored task is never mutated).

        Raises:
            KeyError: If the task is not found.
        """
        if task_id not in self._tasks:
            raise KeyError(f"Task {task_id} not found")

        task = self._tasks[task_id]
        if history_length is not None and task.history:
            task_copy = task.model_copy(deep=True)
            task_copy.history = task.history[-history_length:]
            return task_copy
        return task

    async def update_task_status(
        self,
        task_id: str,
        state: TaskState,
        message: Optional[Message] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> TaskStatusUpdateEvent:
        """Update the status of a task and notify subscribers.

        Args:
            task_id: The ID of the task.
            state: The new state of the task.
            message: An optional message to include with the status update.
            metadata: Additional metadata.

        Returns:
            The task status update event.

        Raises:
            KeyError: If the task is not found.
            ValueError: If the state transition is not allowed.
        """
        if task_id not in self._tasks:
            raise KeyError(f"Task {task_id} not found")

        # (The original fetched the task twice on consecutive lines.)
        task = self._tasks[task_id]
        previous_state = task.status.state if task.status else None

        # Enforce the task state machine before mutating anything.
        if previous_state and not TaskState.is_valid_transition(previous_state, state):
            raise ValueError(f"Invalid state transition from {previous_state} to {state}")

        status = TaskStatus(
            state=state,
            message=message,
            timestamp=datetime.now(),
            previous_state=previous_state,
        )
        task.status = status

        if message and task.history is not None:
            task.history.append(message)

        # Touch the task so the TTL cleanup measures from the last update.
        self._task_timestamps[task_id] = datetime.now()

        event = TaskStatusUpdateEvent(
            id=task_id,
            status=status,
            final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED, TaskState.EXPIRED],
            metadata=metadata or {},
        )

        await self._notify_subscribers(task_id, event)

        return event

    async def add_task_artifact(
        self,
        task_id: str,
        artifact: Artifact,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> TaskArtifactUpdateEvent:
        """Add an artifact to a task and notify subscribers.

        Args:
            task_id: The ID of the task.
            artifact: The artifact to add. When ``artifact.append`` is set
                and an artifact with the same name exists, its parts are
                extended in place instead of adding a new artifact.
            metadata: Additional metadata.

        Returns:
            The task artifact update event.

        Raises:
            KeyError: If the task is not found.
        """
        if task_id not in self._tasks:
            raise KeyError(f"Task {task_id} not found")

        task = self._tasks[task_id]
        if task.artifacts is None:
            task.artifacts = []

        if artifact.append and task.artifacts:
            for existing in task.artifacts:
                if existing.name == artifact.name:
                    existing.parts.extend(artifact.parts)
                    existing.lastChunk = artifact.lastChunk
                    break
            else:
                # No artifact with a matching name: store it as new.
                task.artifacts.append(artifact)
        else:
            task.artifacts.append(artifact)

        event = TaskArtifactUpdateEvent(
            id=task_id,
            artifact=artifact,
            metadata=metadata or {},
        )

        await self._notify_subscribers(task_id, event)

        return event

    async def cancel_task(self, task_id: str) -> Task:
        """Cancel a task.

        Args:
            task_id: The ID of the task.

        Returns:
            The canceled task.

        Raises:
            KeyError: If the task is not found.
        """
        if task_id not in self._tasks:
            raise KeyError(f"Task {task_id} not found")

        task = self._tasks[task_id]

        # Only non-terminal tasks can be canceled. EXPIRED is included
        # here (it was previously missing) so we never attempt a state
        # transition out of an already-expired task.
        terminal_states = [
            TaskState.COMPLETED,
            TaskState.CANCELED,
            TaskState.FAILED,
            TaskState.EXPIRED,
        ]
        if task.status.state not in terminal_states:
            await self.update_task_status(task_id, TaskState.CANCELED)

        return task

    async def set_push_notification(
        self, task_id: str, config: PushNotificationConfig
    ) -> PushNotificationConfig:
        """Set push notification for a task.

        Args:
            task_id: The ID of the task.
            config: The push notification configuration.

        Returns:
            The push notification configuration.

        Raises:
            KeyError: If the task is not found.
        """
        if task_id not in self._tasks:
            raise KeyError(f"Task {task_id} not found")

        self._push_notifications[task_id] = config
        return config

    async def get_push_notification(
        self, task_id: str
    ) -> Optional[PushNotificationConfig]:
        """Get push notification for a task.

        Args:
            task_id: The ID of the task.

        Returns:
            The push notification configuration, or None if not set.

        Raises:
            KeyError: If the task is not found.
        """
        if task_id not in self._tasks:
            raise KeyError(f"Task {task_id} not found")

        return self._push_notifications.get(task_id)

    async def subscribe_to_task(self, task_id: str) -> asyncio.Queue:
        """Subscribe to task updates.

        Args:
            task_id: The ID of the task.

        Returns:
            A queue that will receive task updates.

        Raises:
            KeyError: If the task is not found.
        """
        if task_id not in self._tasks:
            raise KeyError(f"Task {task_id} not found")

        queue: asyncio.Queue = asyncio.Queue()
        self._task_subscribers.setdefault(task_id, set()).add(queue)
        return queue

    async def unsubscribe_from_task(self, task_id: str, queue: asyncio.Queue) -> None:
        """Unsubscribe from task updates.

        Args:
            task_id: The ID of the task.
            queue: The queue to unsubscribe.
        """
        if task_id in self._task_subscribers:
            self._task_subscribers[task_id].discard(queue)

    async def _notify_subscribers(
        self,
        task_id: str,
        event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent],
    ) -> None:
        """Notify subscribers of a task update.

        Args:
            task_id: The ID of the task.
            event: The event to send to subscribers.
        """
        if task_id in self._task_subscribers:
            for queue in self._task_subscribers[task_id]:
                await queue.put(event)

    async def _periodic_cleanup(self) -> None:
        """Periodically clean up expired tasks until cancelled."""
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired_tasks()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self._logger.exception(f"Error during periodic cleanup: {e}")

    async def _cleanup_expired_tasks(self) -> None:
        """Remove tasks whose TTL has elapsed and notify their subscribers."""
        now = datetime.now()
        expired_tasks = [
            task_id
            for task_id, timestamp in self._task_timestamps.items()
            if (now - timestamp).total_seconds() > self._task_ttl
        ]

        for task_id in expired_tasks:
            self._logger.info(f"Cleaning up expired task: {task_id}")

            # Capture the last known state BEFORE dropping the task: the
            # original code popped the task first, so previous_state was
            # always None.
            previous_state = None
            expired = self._tasks.get(task_id)
            if expired and expired.status:
                previous_state = expired.status.state

            self._tasks.pop(task_id, None)
            self._push_notifications.pop(task_id, None)
            self._task_timestamps.pop(task_id, None)

            if task_id in self._task_subscribers:
                status = TaskStatus(
                    state=TaskState.EXPIRED,
                    timestamp=now,
                    previous_state=previous_state,
                )
                # Use the same `id=` field as every other status event in
                # this module (the original used `task_id=` only here).
                event = TaskStatusUpdateEvent(
                    id=task_id,
                    status=status,
                    final=True,
                )
                await self._notify_subscribers(task_id, event)

                self._task_subscribers.pop(task_id, None)
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
from crewai.a2a import A2AAgentIntegration
|
||||
from crewai.agents import CacheHandler
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
@@ -131,14 +132,29 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
a2a_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether the agent supports the A2A protocol.",
|
||||
)
|
||||
a2a_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The URL where the agent's A2A server is hosted.",
|
||||
)
|
||||
_knowledge: Optional[Knowledge] = PrivateAttr(
|
||||
default=None,
|
||||
)
|
||||
_a2a_integration: Optional[A2AAgentIntegration] = PrivateAttr(
|
||||
default=None,
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self):
|
||||
self._set_knowledge()
|
||||
self.agent_ops_agent_name = self.role
|
||||
|
||||
if self.a2a_enabled:
|
||||
self._a2a_integration = A2AAgentIntegration()
|
||||
|
||||
unaccepted_attributes = [
|
||||
"AWS_ACCESS_KEY_ID",
|
||||
"AWS_SECRET_ACCESS_KEY",
|
||||
@@ -294,7 +310,14 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
if self.crew and self.crew.memory:
|
||||
memory = self.crew.contextual_memory.build_context_for_task(task, context)
|
||||
contextual_memory = ContextualMemory(
|
||||
self.crew.memory_config,
|
||||
self.crew._short_term_memory,
|
||||
self.crew._long_term_memory,
|
||||
self.crew._entity_memory,
|
||||
self.crew._user_memory,
|
||||
)
|
||||
memory = contextual_memory.build_context_for_task(task, context)
|
||||
if memory.strip() != "":
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
@@ -348,6 +371,103 @@ class Agent(BaseAgent):
|
||||
result = tool_result["result"]
|
||||
|
||||
return result
|
||||
|
||||
async def execute_task_via_a2a(
    self,
    task_description: str,
    context: Optional[str] = None,
    agent_url: Optional[str] = None,
    api_key: Optional[str] = None,
    timeout: int = 300,
) -> str:
    """Execute a task via the A2A protocol.

    Args:
        task_description: The description of the task.
        context: Additional context for the task.
        agent_url: The URL of the agent to execute the task. Defaults to self.a2a_url.
        api_key: The API key to use for authentication.
        timeout: The timeout for the task execution in seconds.

    Returns:
        The result of the task execution.

    Raises:
        ValueError: If A2A is not enabled or no agent URL is provided.
        TimeoutError: If the task execution times out.
        Exception: If there is an error executing the task.
    """
    if not self.a2a_enabled:
        raise ValueError("A2A protocol is not enabled for this agent")

    if not self._a2a_integration:
        self._a2a_integration = A2AAgentIntegration()

    url = agent_url or self.a2a_url
    if not url:
        raise ValueError("No A2A agent URL provided")

    try:
        # This coroutine is necessarily already running on an event loop,
        # so the previous get_event_loop().is_running() check was always
        # true and its asyncio.run(...) fallback was dead code (and would
        # raise RuntimeError inside a running loop if ever reached).
        # get_event_loop() is also deprecated for this use. Just await.
        return await self._a2a_integration.execute_task_via_a2a(
            agent_url=url,
            task_description=task_description,
            context=context,
            api_key=api_key,
            timeout=timeout,
        )
    except Exception as e:
        self._logger.exception(f"Error executing task via A2A: {e}")
        raise
||||
|
||||
async def handle_a2a_task(
|
||||
self,
|
||||
task_id: str,
|
||||
task_description: str,
|
||||
context: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Handle an A2A task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the A2A task.
|
||||
task_description: The description of the task.
|
||||
context: Additional context for the task.
|
||||
|
||||
Returns:
|
||||
The result of the task execution.
|
||||
|
||||
Raises:
|
||||
ValueError: If A2A is not enabled.
|
||||
Exception: If there is an error handling the task.
|
||||
"""
|
||||
if not self.a2a_enabled:
|
||||
raise ValueError("A2A protocol is not enabled for this agent")
|
||||
|
||||
if not self._a2a_integration:
|
||||
self._a2a_integration = A2AAgentIntegration()
|
||||
|
||||
# Create a Task object from the task description
|
||||
task = Task(
|
||||
description=task_description,
|
||||
agent=self,
|
||||
expected_output="text", # Default to text output
|
||||
)
|
||||
|
||||
try:
|
||||
result = self.execute_task(task=task, context=context)
|
||||
return result
|
||||
except Exception as e:
|
||||
self._logger.exception(f"Error handling A2A task: {e}")
|
||||
raise
|
||||
|
||||
def create_agent_executor(
|
||||
self, tools: Optional[List[BaseTool]] = None, task=None
|
||||
|
||||
@@ -358,9 +358,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
|
||||
train_iteration = self.crew._train_iteration
|
||||
if agent_id in training_data and isinstance(train_iteration, int):
|
||||
training_data[agent_id][train_iteration]["improved_output"] = (
|
||||
result.output
|
||||
)
|
||||
training_data[agent_id][train_iteration][
|
||||
"improved_output"
|
||||
] = result.output
|
||||
training_handler.save(training_data)
|
||||
else:
|
||||
self._printer.print(
|
||||
|
||||
@@ -153,12 +153,8 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
login_response_json = login_response.json()
|
||||
|
||||
settings = Settings()
|
||||
settings.tool_repository_username = login_response_json["credential"][
|
||||
"username"
|
||||
]
|
||||
settings.tool_repository_password = login_response_json["credential"][
|
||||
"password"
|
||||
]
|
||||
settings.tool_repository_username = login_response_json["credential"]["username"]
|
||||
settings.tool_repository_password = login_response_json["credential"]["password"]
|
||||
settings.dump()
|
||||
|
||||
console.print(
|
||||
@@ -183,7 +179,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
capture_output=False,
|
||||
env=self._build_env_with_credentials(repository_handle),
|
||||
text=True,
|
||||
check=True,
|
||||
check=True
|
||||
)
|
||||
|
||||
if add_package_result.stderr:
|
||||
@@ -208,11 +204,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
settings = Settings()
|
||||
|
||||
env = os.environ.copy()
|
||||
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
|
||||
settings.tool_repository_username or ""
|
||||
)
|
||||
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
|
||||
settings.tool_repository_password or ""
|
||||
)
|
||||
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(settings.tool_repository_username or "")
|
||||
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(settings.tool_repository_password or "")
|
||||
|
||||
return env
|
||||
|
||||
@@ -25,7 +25,6 @@ from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
@@ -279,13 +278,6 @@ class Crew(BaseModel):
|
||||
)
|
||||
else:
|
||||
self._user_memory = None
|
||||
self.contextual_memory = ContextualMemory(
|
||||
memory_config=self.memory_config,
|
||||
stm=self._short_term_memory,
|
||||
ltm=self._long_term_memory,
|
||||
em=self._entity_memory,
|
||||
um=self._user_memory,
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
@@ -14,13 +14,13 @@ class Knowledge(BaseModel):
|
||||
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
Args:
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
||||
embedder_config: Optional[Dict[str, Any]] = None
|
||||
"""
|
||||
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
||||
embedder_config: Optional[Dict[str, Any]] = None
|
||||
collection_name: Optional[str] = None
|
||||
|
||||
@@ -49,13 +49,8 @@ class Knowledge(BaseModel):
|
||||
"""
|
||||
Query across all knowledge sources to find the most relevant information.
|
||||
Returns the top_k most relevant chunks.
|
||||
|
||||
Raises:
|
||||
ValueError: If storage is not initialized.
|
||||
"""
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
|
||||
results = self.storage.search(
|
||||
query,
|
||||
limit,
|
||||
|
||||
@@ -22,7 +22,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
default_factory=list, description="The path to the file"
|
||||
)
|
||||
content: Dict[Path, str] = Field(init=False, default_factory=dict)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
||||
safe_file_paths: List[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
@@ -62,10 +62,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
|
||||
def _save_documents(self):
|
||||
"""Save the documents to the storage."""
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
self.storage.save(self.chunks)
|
||||
|
||||
def convert_to_path(self, path: Union[Path, str]) -> Path:
|
||||
"""Convert a path to a Path object."""
|
||||
|
||||
@@ -16,7 +16,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
|
||||
collection_name: Optional[str] = Field(default=None)
|
||||
|
||||
@@ -46,7 +46,4 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
Save the documents to the storage.
|
||||
This method should be called after the chunks and embeddings are generated.
|
||||
"""
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
self.storage.save(self.chunks)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from crewai.task import Task
|
||||
|
||||
from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMemory
|
||||
|
||||
@@ -11,7 +10,7 @@ class ContextualMemory:
|
||||
stm: ShortTermMemory,
|
||||
ltm: LongTermMemory,
|
||||
em: EntityMemory,
|
||||
um: Optional[UserMemory],
|
||||
um: UserMemory,
|
||||
):
|
||||
if memory_config is not None:
|
||||
self.memory_provider = memory_config.get("provider")
|
||||
@@ -22,7 +21,7 @@ class ContextualMemory:
|
||||
self.em = em
|
||||
self.um = um
|
||||
|
||||
def build_context_for_task(self, task: Task, context: str) -> str:
|
||||
def build_context_for_task(self, task, context) -> str:
|
||||
"""
|
||||
Automatically builds a minimal, highly relevant set of contextual information
|
||||
for a given task.
|
||||
@@ -40,7 +39,7 @@ class ContextualMemory:
|
||||
context.append(self._fetch_user_context(query))
|
||||
return "\n".join(filter(None, context))
|
||||
|
||||
def _fetch_stm_context(self, query: str) -> str:
|
||||
def _fetch_stm_context(self, query) -> str:
|
||||
"""
|
||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -54,7 +53,7 @@ class ContextualMemory:
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
def _fetch_ltm_context(self, task: str) -> Optional[str]:
|
||||
def _fetch_ltm_context(self, task) -> Optional[str]:
|
||||
"""
|
||||
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -73,7 +72,7 @@ class ContextualMemory:
|
||||
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
def _fetch_entity_context(self, query: str) -> str:
|
||||
def _fetch_entity_context(self, query) -> str:
|
||||
"""
|
||||
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -95,8 +94,6 @@ class ContextualMemory:
|
||||
Returns:
|
||||
str: Formatted user memories as bullet points, or an empty string if none found.
|
||||
"""
|
||||
if not self.um:
|
||||
return ""
|
||||
user_memories = self.um.search(query)
|
||||
if not user_memories:
|
||||
return ""
|
||||
|
||||
@@ -11,7 +11,7 @@ class EntityMemory(Memory):
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
if hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
self.memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
self.memory_provider = None
|
||||
|
||||
@@ -15,17 +15,8 @@ class LongTermMemory(Memory):
|
||||
"""
|
||||
|
||||
def __init__(self, storage=None, path=None):
|
||||
"""Initialize long term memory.
|
||||
|
||||
Args:
|
||||
storage: Optional custom storage instance
|
||||
path: Optional custom path for storage location
|
||||
|
||||
Note:
|
||||
If both storage and path are provided, storage takes precedence
|
||||
"""
|
||||
if not storage:
|
||||
storage = LTMSQLiteStorage(storage_path=path) if path else LTMSQLiteStorage()
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage)
|
||||
|
||||
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
|
||||
@@ -15,7 +15,7 @@ class ShortTermMemory(Memory):
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
if hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
self.memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
self.memory_provider = None
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, TypeVar
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from crewai.utilities.paths import get_default_storage_path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class BaseRAGStorage(ABC):
|
||||
@@ -18,46 +12,17 @@ class BaseRAGStorage(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
type: str,
|
||||
storage_path: Optional[Path] = None,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Optional[Any] = None,
|
||||
crew: Any = None,
|
||||
) -> None:
|
||||
"""Initialize the BaseRAGStorage.
|
||||
|
||||
Args:
|
||||
type: Type of storage being used
|
||||
storage_path: Optional custom path for storage location
|
||||
allow_reset: Whether storage can be reset
|
||||
embedder_config: Optional configuration for the embedder
|
||||
crew: Optional crew instance this storage belongs to
|
||||
|
||||
Raises:
|
||||
PermissionError: If storage path is not writable
|
||||
OSError: If storage path cannot be created
|
||||
"""
|
||||
):
|
||||
self.type = type
|
||||
self.storage_path = storage_path if storage_path else get_default_storage_path('rag')
|
||||
|
||||
# Validate storage path
|
||||
try:
|
||||
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not os.access(self.storage_path.parent, os.W_OK):
|
||||
raise PermissionError(f"No write permission for storage path: {self.storage_path}")
|
||||
except OSError as e:
|
||||
raise OSError(f"Failed to initialize storage path: {str(e)}")
|
||||
|
||||
self.allow_reset = allow_reset
|
||||
self.embedder_config = embedder_config
|
||||
self.crew = crew
|
||||
self.agents = self._initialize_agents()
|
||||
|
||||
def _initialize_agents(self) -> str:
|
||||
"""Initialize agent identifiers for storage.
|
||||
|
||||
Returns:
|
||||
str: Underscore-joined string of sanitized agent role names
|
||||
"""
|
||||
if self.crew:
|
||||
return "_".join(
|
||||
[self._sanitize_role(agent.role) for agent in self.crew.agents]
|
||||
@@ -66,27 +31,12 @@ class BaseRAGStorage(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""Sanitizes agent roles to ensure valid directory names.
|
||||
|
||||
Args:
|
||||
role: The agent role name to sanitize
|
||||
|
||||
Returns:
|
||||
str: Sanitized role name safe for use in paths
|
||||
"""
|
||||
"""Sanitizes agent roles to ensure valid directory names."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
"""Save a value with metadata to the storage.
|
||||
|
||||
Args:
|
||||
value: The value to store
|
||||
metadata: Additional metadata to store with the value
|
||||
|
||||
Raises:
|
||||
OSError: If there is an error writing to storage
|
||||
"""
|
||||
"""Save a value with metadata to the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -96,55 +46,25 @@ class BaseRAGStorage(ABC):
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for entries in the storage.
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
limit: Maximum number of results to return
|
||||
filter: Optional filter criteria
|
||||
score_threshold: Minimum similarity score threshold
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of matching entries with their metadata
|
||||
"""
|
||||
) -> List[Any]:
|
||||
"""Search for entries in the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the storage.
|
||||
|
||||
Raises:
|
||||
OSError: If there is an error clearing storage
|
||||
PermissionError: If reset is not allowed
|
||||
"""
|
||||
"""Reset the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _generate_embedding(
|
||||
self, text: str, metadata: Optional[Dict[str, Any]] = None
|
||||
) -> List[float]:
|
||||
"""Generate an embedding for the given text and metadata.
|
||||
|
||||
Args:
|
||||
text: Text to generate embedding for
|
||||
metadata: Optional metadata to include in embedding
|
||||
|
||||
Returns:
|
||||
List[float]: Vector embedding of the text
|
||||
|
||||
Raises:
|
||||
ValueError: If text is empty or invalid
|
||||
"""
|
||||
) -> Any:
|
||||
"""Generate an embedding for the given text and metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_app(self) -> None:
|
||||
"""Initialize the vector db.
|
||||
|
||||
Raises:
|
||||
OSError: If vector db initialization fails
|
||||
"""
|
||||
def _initialize_app(self):
|
||||
"""Initialize the vector db."""
|
||||
pass
|
||||
|
||||
def setup_config(self, config: Dict[str, Any]):
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from crewai.task import Task
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
|
||||
from crewai.utilities.paths import get_default_storage_path
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class KickoffTaskOutputsSQLiteStorage:
|
||||
@@ -15,26 +13,10 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
An updated SQLite storage class for kickoff task outputs storage.
|
||||
"""
|
||||
|
||||
def __init__(self, storage_path: Optional[Path] = None) -> None:
|
||||
"""Initialize kickoff task outputs storage.
|
||||
|
||||
Args:
|
||||
storage_path: Optional custom path for storage location
|
||||
|
||||
Raises:
|
||||
PermissionError: If storage path is not writable
|
||||
OSError: If storage path cannot be created
|
||||
"""
|
||||
self.storage_path = storage_path if storage_path else get_default_storage_path('kickoff')
|
||||
|
||||
# Validate storage path
|
||||
try:
|
||||
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not os.access(self.storage_path.parent, os.W_OK):
|
||||
raise PermissionError(f"No write permission for storage path: {self.storage_path}")
|
||||
except OSError as e:
|
||||
raise OSError(f"Failed to initialize storage path: {str(e)}")
|
||||
|
||||
def __init__(
|
||||
self, db_path: str = f"{db_storage_path()}/latest_kickoff_task_outputs.db"
|
||||
) -> None:
|
||||
self.db_path = db_path
|
||||
self._printer: Printer = Printer()
|
||||
self._initialize_db()
|
||||
|
||||
@@ -43,7 +25,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
Initializes the SQLite database and creates LTM table
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(str(self.storage_path)) as conn:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -73,21 +55,9 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
task_index: int,
|
||||
was_replayed: bool = False,
|
||||
inputs: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
"""Add a task output to storage.
|
||||
|
||||
Args:
|
||||
task: The task whose output is being stored
|
||||
output: The output data from the task
|
||||
task_index: Index of this task in the sequence
|
||||
was_replayed: Whether this was from a replay
|
||||
inputs: Optional input data that led to this output
|
||||
|
||||
Raises:
|
||||
sqlite3.Error: If there is an error saving to database
|
||||
"""
|
||||
):
|
||||
try:
|
||||
with sqlite3.connect(str(self.storage_path)) as conn:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -120,7 +90,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
Updates an existing row in the latest_kickoff_task_outputs table based on task_index.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(str(self.storage_path)) as conn:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
fields = []
|
||||
@@ -149,7 +119,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
def load(self) -> Optional[List[Dict[str, Any]]]:
|
||||
try:
|
||||
with sqlite3.connect(str(self.storage_path)) as conn:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT *
|
||||
@@ -185,7 +155,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
Deletes all rows from the latest_kickoff_task_outputs table.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(str(self.storage_path)) as conn:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
|
||||
conn.commit()
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.paths import get_default_storage_path
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class LTMSQLiteStorage:
|
||||
@@ -13,26 +11,10 @@ class LTMSQLiteStorage:
|
||||
An updated SQLite storage class for LTM data storage.
|
||||
"""
|
||||
|
||||
def __init__(self, storage_path: Optional[Path] = None) -> None:
|
||||
"""Initialize LTM SQLite storage.
|
||||
|
||||
Args:
|
||||
storage_path: Optional custom path for storage location
|
||||
|
||||
Raises:
|
||||
PermissionError: If storage path is not writable
|
||||
OSError: If storage path cannot be created
|
||||
"""
|
||||
self.storage_path = storage_path if storage_path else get_default_storage_path('ltm')
|
||||
|
||||
# Validate storage path
|
||||
try:
|
||||
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not os.access(self.storage_path.parent, os.W_OK):
|
||||
raise PermissionError(f"No write permission for storage path: {self.storage_path}")
|
||||
except OSError as e:
|
||||
raise OSError(f"Failed to initialize storage path: {str(e)}")
|
||||
|
||||
def __init__(
|
||||
self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db"
|
||||
) -> None:
|
||||
self.db_path = db_path
|
||||
self._printer: Printer = Printer()
|
||||
self._initialize_db()
|
||||
|
||||
@@ -41,7 +23,7 @@ class LTMSQLiteStorage:
|
||||
Initializes the SQLite database and creates LTM table
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(str(self.storage_path)) as conn:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -69,20 +51,9 @@ class LTMSQLiteStorage:
|
||||
datetime: str,
|
||||
score: Union[int, float],
|
||||
) -> None:
|
||||
"""Save a memory entry to long-term memory.
|
||||
|
||||
Args:
|
||||
task_description: Description of the task this memory relates to
|
||||
metadata: Additional data to store with the memory
|
||||
datetime: Timestamp for when this memory was created
|
||||
score: Relevance score for this memory (higher is more relevant)
|
||||
|
||||
Raises:
|
||||
sqlite3.Error: If there is an error saving to the database
|
||||
"""
|
||||
"""Saves data to the LTM table with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(str(self.storage_path)) as conn:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -103,7 +74,7 @@ class LTMSQLiteStorage:
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""Queries the LTM table by task description with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(str(self.storage_path)) as conn:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
@@ -138,7 +109,7 @@ class LTMSQLiteStorage:
|
||||
) -> None:
|
||||
"""Resets the LTM table with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(str(self.storage_path)) as conn:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM long_term_memories")
|
||||
conn.commit()
|
||||
|
||||
@@ -19,7 +19,7 @@ class Mem0Storage(Storage):
|
||||
|
||||
self.memory_type = type
|
||||
self.crew = crew
|
||||
self.memory_config = crew.memory_config if crew else None
|
||||
self.memory_config = crew.memory_config
|
||||
|
||||
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
|
||||
user_id = self._get_user_id()
|
||||
@@ -27,10 +27,9 @@ class Mem0Storage(Storage):
|
||||
raise ValueError("User ID is required for user memory type")
|
||||
|
||||
# API key in memory config overrides the environment variable
|
||||
if self.memory_config and self.memory_config.get("config"):
|
||||
mem0_api_key = self.memory_config.get("config").get("api_key")
|
||||
else:
|
||||
mem0_api_key = os.getenv("MEM0_API_KEY")
|
||||
mem0_api_key = self.memory_config.get("config", {}).get("api_key") or os.getenv(
|
||||
"MEM0_API_KEY"
|
||||
)
|
||||
self.memory = MemoryClient(api_key=mem0_api_key)
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
|
||||
@@ -11,6 +11,7 @@ from chromadb.api import ClientAPI
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -39,15 +40,9 @@ class RAGStorage(BaseRAGStorage):
|
||||
app: ClientAPI | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type,
|
||||
storage_path=None,
|
||||
allow_reset=True,
|
||||
embedder_config=None,
|
||||
crew=None,
|
||||
path=None,
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||
):
|
||||
super().__init__(type, storage_path, allow_reset, embedder_config, crew)
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
@@ -95,7 +90,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
Ensures file name does not exceed max allowed by OS
|
||||
"""
|
||||
base_path = f"{self.storage_path}/{type}"
|
||||
base_path = f"{db_storage_path()}/{type}"
|
||||
|
||||
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
||||
logging.warning(
|
||||
@@ -157,7 +152,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
try:
|
||||
if self.app:
|
||||
self.app.reset()
|
||||
shutil.rmtree(f"{self.storage_path}/{self.type}")
|
||||
shutil.rmtree(f"{db_storage_path()}/{self.type}")
|
||||
self.app = None
|
||||
self.collection = None
|
||||
except Exception as e:
|
||||
|
||||
@@ -66,6 +66,7 @@ def cache_handler(func):
|
||||
|
||||
|
||||
def crew(func) -> Callable[..., Crew]:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs) -> Crew:
|
||||
instantiated_tasks = []
|
||||
|
||||
@@ -216,5 +216,5 @@ def CrewBase(cls: T) -> T:
|
||||
# Include base class (qual)name in the wrapper class (qual)name.
|
||||
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
|
||||
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
|
||||
|
||||
|
||||
return cast(T, WrappedClass)
|
||||
|
||||
@@ -373,9 +373,7 @@ class Task(BaseModel):
|
||||
content = (
|
||||
json_output
|
||||
if json_output
|
||||
else pydantic_output.model_dump_json()
|
||||
if pydantic_output
|
||||
else result
|
||||
else pydantic_output.model_dump_json() if pydantic_output else result
|
||||
)
|
||||
self._save_file(content)
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Type definitions for CrewAI."""
|
||||
|
||||
469
src/crewai/types/a2a.py
Normal file
469
src/crewai/types/a2a.py
Normal file
@@ -0,0 +1,469 @@
|
||||
"""
|
||||
A2A protocol types for CrewAI.
|
||||
|
||||
This module implements the A2A (Agent-to-Agent) protocol types as defined by Google.
|
||||
The A2A protocol enables interoperability between different agent systems.
|
||||
|
||||
For more information, see: https://developers.googleblog.com/en/a2a-a-new-era-of-agent-interoperability/
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Self, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, field_serializer, model_validator
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
"""Task state in the A2A protocol."""
|
||||
SUBMITTED = 'submitted'
|
||||
WORKING = 'working'
|
||||
INPUT_REQUIRED = 'input-required'
|
||||
COMPLETED = 'completed'
|
||||
CANCELED = 'canceled'
|
||||
FAILED = 'failed'
|
||||
UNKNOWN = 'unknown'
|
||||
EXPIRED = 'expired'
|
||||
|
||||
@classmethod
|
||||
def valid_transitions(cls) -> Dict[str, List[str]]:
|
||||
"""Get valid state transitions.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping from state to list of valid next states.
|
||||
"""
|
||||
return {
|
||||
cls.SUBMITTED: [cls.WORKING, cls.CANCELED, cls.FAILED],
|
||||
cls.WORKING: [cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED],
|
||||
cls.INPUT_REQUIRED: [cls.WORKING, cls.CANCELED, cls.FAILED],
|
||||
cls.COMPLETED: [], # Terminal state
|
||||
cls.CANCELED: [], # Terminal state
|
||||
cls.FAILED: [], # Terminal state
|
||||
cls.UNKNOWN: [cls.SUBMITTED, cls.WORKING, cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED],
|
||||
cls.EXPIRED: [], # Terminal state
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def is_valid_transition(cls, from_state: 'TaskState', to_state: 'TaskState') -> bool:
|
||||
"""Check if a state transition is valid.
|
||||
|
||||
Args:
|
||||
from_state: The current state.
|
||||
to_state: The target state.
|
||||
|
||||
Returns:
|
||||
True if the transition is valid, False otherwise.
|
||||
"""
|
||||
if from_state == to_state:
|
||||
return True
|
||||
|
||||
valid_next_states = cls.valid_transitions().get(from_state, [])
|
||||
return to_state in valid_next_states
|
||||
|
||||
|
||||
class TextPart(BaseModel):
|
||||
"""Text part in the A2A protocol."""
|
||||
type: Literal['text'] = 'text'
|
||||
text: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class FileContent(BaseModel):
|
||||
"""File content in the A2A protocol."""
|
||||
name: Optional[str] = None
|
||||
mimeType: Optional[str] = None
|
||||
bytes: Optional[str] = None
|
||||
uri: Optional[str] = None
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_content(self) -> Self:
|
||||
"""Validate file content has either bytes or uri."""
|
||||
if not (self.bytes or self.uri):
|
||||
raise ValueError(
|
||||
"Either 'bytes' or 'uri' must be present in the file data"
|
||||
)
|
||||
if self.bytes and self.uri:
|
||||
raise ValueError(
|
||||
"Only one of 'bytes' or 'uri' can be present in the file data"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class FilePart(BaseModel):
|
||||
"""File part in the A2A protocol."""
|
||||
type: Literal['file'] = 'file'
|
||||
file: FileContent
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class DataPart(BaseModel):
|
||||
"""Data part in the A2A protocol."""
|
||||
type: Literal['data'] = 'data'
|
||||
data: Dict[str, Any]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator='type')]
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Message in the A2A protocol."""
|
||||
role: Literal['user', 'agent']
|
||||
parts: List[Part]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskStatus(BaseModel):
|
||||
"""Task status in the A2A protocol."""
|
||||
state: TaskState
|
||||
message: Optional[Message] = None
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
previous_state: Optional[TaskState] = None
|
||||
|
||||
@field_serializer('timestamp')
|
||||
def serialize_dt(self, dt: datetime, _info):
|
||||
"""Serialize datetime to ISO format."""
|
||||
return dt.isoformat()
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_state_transition(self) -> Self:
|
||||
"""Validate state transition."""
|
||||
if self.previous_state and not TaskState.is_valid_transition(self.previous_state, self.state):
|
||||
raise ValueError(
|
||||
f"Invalid state transition from {self.previous_state} to {self.state}"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class Artifact(BaseModel):
|
||||
"""Artifact in the A2A protocol."""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
parts: List[Part]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
index: int = 0
|
||||
append: Optional[bool] = None
|
||||
lastChunk: Optional[bool] = None
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
"""Task in the A2A protocol."""
|
||||
id: str
|
||||
sessionId: Optional[str] = None
|
||||
status: TaskStatus
|
||||
artifacts: Optional[List[Artifact]] = None
|
||||
history: Optional[List[Message]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskStatusUpdateEvent(BaseModel):
|
||||
"""Task status update event in the A2A protocol."""
|
||||
id: str
|
||||
status: TaskStatus
|
||||
final: bool = False
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskArtifactUpdateEvent(BaseModel):
|
||||
"""Task artifact update event in the A2A protocol."""
|
||||
id: str
|
||||
artifact: Artifact
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class AuthenticationInfo(BaseModel):
|
||||
"""Authentication information in the A2A protocol."""
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
schemes: List[str]
|
||||
credentials: Optional[str] = None
|
||||
|
||||
|
||||
class PushNotificationConfig(BaseModel):
|
||||
"""Push notification configuration in the A2A protocol."""
|
||||
url: str
|
||||
token: Optional[str] = None
|
||||
authentication: Optional[AuthenticationInfo] = None
|
||||
|
||||
|
||||
class TaskIdParams(BaseModel):
|
||||
"""Task ID parameters in the A2A protocol."""
|
||||
id: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskQueryParams(TaskIdParams):
|
||||
"""Task query parameters in the A2A protocol."""
|
||||
historyLength: Optional[int] = None
|
||||
|
||||
|
||||
class TaskSendParams(BaseModel):
|
||||
"""Task send parameters in the A2A protocol."""
|
||||
id: str
|
||||
sessionId: str = Field(default_factory=lambda: uuid4().hex)
|
||||
message: Message
|
||||
acceptedOutputModes: Optional[List[str]] = None
|
||||
pushNotification: Optional[PushNotificationConfig] = None
|
||||
historyLength: Optional[int] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskPushNotificationConfig(BaseModel):
|
||||
"""Task push notification configuration in the A2A protocol."""
|
||||
id: str
|
||||
pushNotificationConfig: PushNotificationConfig
|
||||
|
||||
|
||||
|
||||
class JSONRPCMessage(BaseModel):
|
||||
"""JSON-RPC message in the A2A protocol."""
|
||||
jsonrpc: Literal['2.0'] = '2.0'
|
||||
id: Optional[Union[int, str]] = Field(default_factory=lambda: uuid4().hex)
|
||||
|
||||
|
||||
class JSONRPCRequest(JSONRPCMessage):
|
||||
"""JSON-RPC request in the A2A protocol."""
|
||||
method: str
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class JSONRPCError(BaseModel):
|
||||
"""JSON-RPC error in the A2A protocol."""
|
||||
code: int
|
||||
message: str
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class JSONRPCResponse(JSONRPCMessage):
|
||||
"""JSON-RPC response in the A2A protocol."""
|
||||
result: Optional[Any] = None
|
||||
error: Optional[JSONRPCError] = None
|
||||
|
||||
|
||||
class SendTaskRequest(JSONRPCRequest):
|
||||
"""Send task request in the A2A protocol."""
|
||||
method: Literal['tasks/send'] = 'tasks/send'
|
||||
params: TaskSendParams
|
||||
|
||||
|
||||
class SendTaskResponse(JSONRPCResponse):
|
||||
"""Send task response in the A2A protocol."""
|
||||
result: Optional[Task] = None
|
||||
|
||||
|
||||
class SendTaskStreamingRequest(JSONRPCRequest):
|
||||
"""Send task streaming request in the A2A protocol."""
|
||||
method: Literal['tasks/sendSubscribe'] = 'tasks/sendSubscribe'
|
||||
params: TaskSendParams
|
||||
|
||||
|
||||
class SendTaskStreamingResponse(JSONRPCResponse):
|
||||
"""Send task streaming response in the A2A protocol."""
|
||||
result: Optional[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]] = None
|
||||
|
||||
|
||||
class GetTaskRequest(JSONRPCRequest):
|
||||
"""Get task request in the A2A protocol."""
|
||||
method: Literal['tasks/get'] = 'tasks/get'
|
||||
params: TaskQueryParams
|
||||
|
||||
|
||||
class GetTaskResponse(JSONRPCResponse):
|
||||
"""Get task response in the A2A protocol."""
|
||||
result: Optional[Task] = None
|
||||
|
||||
|
||||
class CancelTaskRequest(JSONRPCRequest):
|
||||
"""Cancel task request in the A2A protocol."""
|
||||
method: Literal['tasks/cancel'] = 'tasks/cancel'
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
class CancelTaskResponse(JSONRPCResponse):
|
||||
"""Cancel task response in the A2A protocol."""
|
||||
result: Optional[Task] = None
|
||||
|
||||
|
||||
class SetTaskPushNotificationRequest(JSONRPCRequest):
|
||||
"""Set task push notification request in the A2A protocol."""
|
||||
method: Literal['tasks/pushNotification/set'] = 'tasks/pushNotification/set'
|
||||
params: TaskPushNotificationConfig
|
||||
|
||||
|
||||
class SetTaskPushNotificationResponse(JSONRPCResponse):
|
||||
"""Set task push notification response in the A2A protocol."""
|
||||
result: Optional[TaskPushNotificationConfig] = None
|
||||
|
||||
|
||||
class GetTaskPushNotificationRequest(JSONRPCRequest):
|
||||
"""Get task push notification request in the A2A protocol."""
|
||||
method: Literal['tasks/pushNotification/get'] = 'tasks/pushNotification/get'
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
class GetTaskPushNotificationResponse(JSONRPCResponse):
|
||||
"""Get task push notification response in the A2A protocol."""
|
||||
result: Optional[TaskPushNotificationConfig] = None
|
||||
|
||||
|
||||
class TaskResubscriptionRequest(JSONRPCRequest):
|
||||
"""Task resubscription request in the A2A protocol."""
|
||||
method: Literal['tasks/resubscribe'] = 'tasks/resubscribe'
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
A2ARequest = TypeAdapter(
|
||||
Annotated[
|
||||
Union[
|
||||
SendTaskRequest,
|
||||
GetTaskRequest,
|
||||
CancelTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
TaskResubscriptionRequest,
|
||||
SendTaskStreamingRequest,
|
||||
],
|
||||
Field(discriminator='method'),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class JSONParseError(JSONRPCError):
|
||||
"""JSON parse error in the A2A protocol."""
|
||||
code: int = -32700
|
||||
message: str = 'Invalid JSON payload'
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class InvalidRequestError(JSONRPCError):
|
||||
"""Invalid request error in the A2A protocol."""
|
||||
code: int = -32600
|
||||
message: str = 'Request payload validation error'
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class MethodNotFoundError(JSONRPCError):
|
||||
"""Method not found error in the A2A protocol."""
|
||||
code: int = -32601
|
||||
message: str = 'Method not found'
|
||||
data: None = None
|
||||
|
||||
|
||||
class InvalidParamsError(JSONRPCError):
|
||||
"""Invalid parameters error in the A2A protocol."""
|
||||
code: int = -32602
|
||||
message: str = 'Invalid parameters'
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class InternalError(JSONRPCError):
|
||||
"""Internal error in the A2A protocol."""
|
||||
code: int = -32603
|
||||
message: str = 'Internal error'
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class TaskNotFoundError(JSONRPCError):
|
||||
"""Task not found error in the A2A protocol."""
|
||||
code: int = -32001
|
||||
message: str = 'Task not found'
|
||||
data: None = None
|
||||
|
||||
|
||||
class TaskNotCancelableError(JSONRPCError):
|
||||
"""Task not cancelable error in the A2A protocol."""
|
||||
code: int = -32002
|
||||
message: str = 'Task cannot be canceled'
|
||||
data: None = None
|
||||
|
||||
|
||||
class PushNotificationNotSupportedError(JSONRPCError):
|
||||
"""Push notification not supported error in the A2A protocol."""
|
||||
code: int = -32003
|
||||
message: str = 'Push Notification is not supported'
|
||||
data: None = None
|
||||
|
||||
|
||||
class UnsupportedOperationError(JSONRPCError):
|
||||
"""Unsupported operation error in the A2A protocol."""
|
||||
code: int = -32004
|
||||
message: str = 'This operation is not supported'
|
||||
data: None = None
|
||||
|
||||
|
||||
class ContentTypeNotSupportedError(JSONRPCError):
|
||||
"""Content type not supported error in the A2A protocol."""
|
||||
code: int = -32005
|
||||
message: str = 'Incompatible content types'
|
||||
data: None = None
|
||||
|
||||
|
||||
class AgentProvider(BaseModel):
|
||||
"""Agent provider in the A2A protocol."""
|
||||
organization: str
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class AgentCapabilities(BaseModel):
|
||||
"""Agent capabilities in the A2A protocol."""
|
||||
streaming: bool = False
|
||||
pushNotifications: bool = False
|
||||
stateTransitionHistory: bool = False
|
||||
|
||||
|
||||
class AgentAuthentication(BaseModel):
|
||||
"""Agent authentication in the A2A protocol."""
|
||||
schemes: List[str]
|
||||
credentials: Optional[str] = None
|
||||
|
||||
|
||||
class AgentSkill(BaseModel):
|
||||
"""Agent skill in the A2A protocol."""
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
examples: Optional[List[str]] = None
|
||||
inputModes: Optional[List[str]] = None
|
||||
outputModes: Optional[List[str]] = None
|
||||
|
||||
|
||||
class AgentCard(BaseModel):
|
||||
"""Agent card in the A2A protocol."""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
url: str
|
||||
provider: Optional[AgentProvider] = None
|
||||
version: str
|
||||
documentationUrl: Optional[str] = None
|
||||
capabilities: AgentCapabilities
|
||||
authentication: Optional[AgentAuthentication] = None
|
||||
defaultInputModes: List[str] = ['text']
|
||||
defaultOutputModes: List[str] = ['text']
|
||||
skills: List[AgentSkill]
|
||||
|
||||
|
||||
class A2AClientError(Exception):
|
||||
"""Base exception for A2A client errors."""
|
||||
pass
|
||||
|
||||
|
||||
class A2AClientHTTPError(A2AClientError):
|
||||
"""HTTP error in the A2A client."""
|
||||
def __init__(self, status_code: int, message: str):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
super().__init__(f'HTTP Error {status_code}: {message}')
|
||||
|
||||
|
||||
class A2AClientJSONError(A2AClientError):
|
||||
"""JSON error in the A2A client."""
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(f'JSON Error: {message}')
|
||||
|
||||
|
||||
class MissingAPIKeyError(Exception):
|
||||
"""Exception for missing API key."""
|
||||
pass
|
||||
@@ -27,7 +27,7 @@ class EmbeddingConfigurator:
|
||||
if embedder_config is None:
|
||||
return self._create_default_embedding_function()
|
||||
|
||||
provider = embedder_config.get("provider", "")
|
||||
provider = embedder_config.get("provider")
|
||||
config = embedder_config.get("config", {})
|
||||
model_name = config.get("model")
|
||||
|
||||
@@ -38,13 +38,12 @@ class EmbeddingConfigurator:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||
|
||||
embedding_function = self.embedding_functions.get(provider, None)
|
||||
if not embedding_function:
|
||||
if provider not in self.embedding_functions:
|
||||
raise Exception(
|
||||
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
||||
)
|
||||
|
||||
return embedding_function(config, model_name)
|
||||
return self.embedding_functions[provider](config, model_name)
|
||||
|
||||
@staticmethod
|
||||
def _create_default_embedding_function():
|
||||
|
||||
@@ -22,26 +22,3 @@ def get_project_directory_name():
|
||||
cwd = Path.cwd()
|
||||
project_directory_name = cwd.name
|
||||
return project_directory_name
|
||||
|
||||
def get_default_storage_path(storage_type: str) -> Path:
|
||||
"""Returns the default storage path for a given storage type.
|
||||
|
||||
Args:
|
||||
storage_type: Type of storage ('ltm', 'kickoff', 'rag')
|
||||
|
||||
Returns:
|
||||
Path: Default storage path for the specified type
|
||||
|
||||
Raises:
|
||||
ValueError: If storage_type is not recognized
|
||||
"""
|
||||
base_path = db_storage_path()
|
||||
|
||||
if storage_type == 'ltm':
|
||||
return base_path / 'latest_long_term_memories.db'
|
||||
elif storage_type == 'kickoff':
|
||||
return base_path / 'latest_kickoff_task_outputs.db'
|
||||
elif storage_type == 'rag':
|
||||
return base_path
|
||||
else:
|
||||
raise ValueError(f"Unknown storage type: {storage_type}")
|
||||
|
||||
1
tests/a2a/__init__.py
Normal file
1
tests/a2a/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for the A2A protocol implementation."""
|
||||
240
tests/a2a/test_a2a_integration.py
Normal file
240
tests/a2a/test_a2a_integration.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Tests for the A2A protocol integration."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer, InMemoryTaskManager
|
||||
from crewai.task import Task
|
||||
from crewai.types.a2a import (
|
||||
JSONRPCResponse,
|
||||
Message,
|
||||
Task as A2ATask,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
"""Create an agent with A2A enabled."""
|
||||
return Agent(
|
||||
role="test_agent",
|
||||
goal="Test A2A protocol",
|
||||
backstory="I am a test agent",
|
||||
a2a_enabled=True,
|
||||
a2a_url="http://localhost:8000",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task():
|
||||
"""Create a task."""
|
||||
return Task(
|
||||
description="Test task",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def a2a_task():
|
||||
"""Create an A2A task."""
|
||||
return A2ATask(
|
||||
id="test_task_id",
|
||||
history=[
|
||||
Message(
|
||||
role="user",
|
||||
parts=[TextPart(text="Test task description")],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def a2a_integration():
|
||||
"""Create an A2A integration."""
|
||||
return A2AAgentIntegration()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def a2a_client():
|
||||
"""Create an A2A client."""
|
||||
return A2AClient(base_url="http://localhost:8000", api_key="test_api_key")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_manager():
|
||||
"""Create a task manager."""
|
||||
return InMemoryTaskManager()
|
||||
|
||||
|
||||
class TestA2AIntegration:
|
||||
"""Tests for the A2A protocol integration."""
|
||||
|
||||
def test_agent_a2a_attributes(self, agent):
|
||||
"""Test that the agent has A2A attributes."""
|
||||
assert agent.a2a_enabled is True
|
||||
assert agent.a2a_url == "http://localhost:8000"
|
||||
assert agent._a2a_integration is not None
|
||||
|
||||
@patch("crewai.a2a.agent.A2AAgentIntegration.execute_task_via_a2a")
|
||||
def test_execute_task_via_a2a(self, mock_execute, agent):
|
||||
"""Test executing a task via A2A."""
|
||||
mock_execute.return_value = "Task result"
|
||||
|
||||
result = asyncio.run(
|
||||
agent.execute_task_via_a2a(
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
)
|
||||
)
|
||||
|
||||
assert result == "Task result"
|
||||
mock_execute.assert_called_once_with(
|
||||
agent_url="http://localhost:8000",
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
api_key=None,
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
@patch("crewai.agent.Agent.execute_task")
|
||||
def test_handle_a2a_task(self, mock_execute, agent):
|
||||
"""Test handling an A2A task."""
|
||||
mock_execute.return_value = "Task result"
|
||||
|
||||
result = asyncio.run(
|
||||
agent.handle_a2a_task(
|
||||
task_id="test_task_id",
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
)
|
||||
)
|
||||
|
||||
assert result == "Task result"
|
||||
mock_execute.assert_called_once()
|
||||
args, kwargs = mock_execute.call_args
|
||||
assert kwargs["context"] == "Test context"
|
||||
assert kwargs["task"].description == "Test task"
|
||||
|
||||
def test_a2a_disabled(self, agent):
|
||||
"""Test that A2A methods raise ValueError when A2A is disabled."""
|
||||
agent.a2a_enabled = False
|
||||
|
||||
with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
|
||||
asyncio.run(
|
||||
agent.execute_task_via_a2a(
|
||||
task_description="Test task",
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
|
||||
asyncio.run(
|
||||
agent.handle_a2a_task(
|
||||
task_id="test_task_id",
|
||||
task_description="Test task",
|
||||
)
|
||||
)
|
||||
|
||||
def test_no_agent_url(self, agent):
|
||||
"""Test that execute_task_via_a2a raises ValueError when no agent URL is provided."""
|
||||
agent.a2a_url = None
|
||||
|
||||
with pytest.raises(ValueError, match="No A2A agent URL provided"):
|
||||
asyncio.run(
|
||||
agent.execute_task_via_a2a(
|
||||
task_description="Test task",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestA2AAgentIntegration:
|
||||
"""Tests for the A2AAgentIntegration class."""
|
||||
|
||||
@patch("crewai.a2a.client.A2AClient.send_task_streaming")
|
||||
async def test_execute_task_via_a2a(self, mock_send_task, a2a_integration):
|
||||
"""Test executing a task via A2A."""
|
||||
queue = asyncio.Queue()
|
||||
await queue.put(
|
||||
TaskStatusUpdateEvent(
|
||||
id="test_task_id",
|
||||
status=TaskStatus(
|
||||
state=TaskState.COMPLETED,
|
||||
message=Message(
|
||||
role="agent",
|
||||
parts=[TextPart(text="Task result")],
|
||||
),
|
||||
),
|
||||
final=True,
|
||||
)
|
||||
)
|
||||
|
||||
mock_send_task.return_value = queue
|
||||
|
||||
result = await a2a_integration.execute_task_via_a2a(
|
||||
agent_url="http://localhost:8000",
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
)
|
||||
|
||||
assert result == "Task result"
|
||||
mock_send_task.assert_called_once()
|
||||
|
||||
|
||||
class TestA2AServer:
|
||||
"""Tests for the A2AServer class."""
|
||||
|
||||
@patch("fastapi.FastAPI.post")
|
||||
def test_server_initialization(self, mock_post, task_manager):
|
||||
"""Test server initialization."""
|
||||
server = A2AServer(task_manager=task_manager)
|
||||
assert server.task_manager == task_manager
|
||||
assert server.app is not None
|
||||
assert mock_post.call_count == 4 # 4 endpoints registered
|
||||
|
||||
|
||||
class TestA2AClient:
|
||||
"""Tests for the A2AClient class."""
|
||||
|
||||
@patch("crewai.a2a.client.A2AClient._send_jsonrpc_request")
|
||||
async def test_send_task(self, mock_send_request, a2a_client):
|
||||
"""Test sending a task."""
|
||||
mock_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id="test_request_id",
|
||||
result=A2ATask(
|
||||
id="test_task_id",
|
||||
sessionId="test_session_id",
|
||||
status=TaskStatus(
|
||||
state=TaskState.SUBMITTED,
|
||||
timestamp=datetime.now(),
|
||||
),
|
||||
history=[
|
||||
Message(
|
||||
role="user",
|
||||
parts=[TextPart(text="Test task description")],
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
mock_send_request.return_value = mock_response
|
||||
|
||||
task = await a2a_client.send_task(
|
||||
task_id="test_task_id",
|
||||
message=Message(
|
||||
role="user",
|
||||
parts=[TextPart(text="Test task description")],
|
||||
),
|
||||
session_id="test_session_id",
|
||||
)
|
||||
|
||||
assert task.id == "test_task_id"
|
||||
assert task.history[0].role == "user"
|
||||
assert task.history[0].parts[0].text == "Test task description"
|
||||
mock_send_request.assert_called_once()
|
||||
@@ -28,10 +28,9 @@ def test_create_success(mock_subprocess):
|
||||
with in_temp_dir():
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with (
|
||||
patch.object(tool_command, "login") as mock_login,
|
||||
patch("sys.stdout", new=StringIO()) as fake_out,
|
||||
):
|
||||
with patch.object(tool_command, "login") as mock_login, patch(
|
||||
"sys.stdout", new=StringIO()
|
||||
) as fake_out:
|
||||
tool_command.create("test-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
@@ -83,7 +82,7 @@ def test_install_success(mock_get, mock_subprocess_run):
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
env=unittest.mock.ANY,
|
||||
env=unittest.mock.ANY
|
||||
)
|
||||
|
||||
assert "Successfully installed sample-tool" in output
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||
from crewai.memory.storage.kickoff_task_outputs_storage import KickoffTaskOutputsSQLiteStorage
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities.paths import get_default_storage_path
|
||||
|
||||
class MockRAGStorage(BaseRAGStorage):
|
||||
"""Mock implementation of BaseRAGStorage for testing."""
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
return role.lower()
|
||||
|
||||
def save(self, value, metadata):
|
||||
pass
|
||||
|
||||
def search(self, query, limit=3, filter=None, score_threshold=0.35):
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def _generate_embedding(self, text, metadata=None):
|
||||
return []
|
||||
|
||||
def _initialize_app(self):
|
||||
pass
|
||||
|
||||
def test_default_storage_paths():
|
||||
"""Test that default storage paths are created correctly."""
|
||||
ltm_path = get_default_storage_path('ltm')
|
||||
kickoff_path = get_default_storage_path('kickoff')
|
||||
rag_path = get_default_storage_path('rag')
|
||||
|
||||
assert str(ltm_path).endswith('latest_long_term_memories.db')
|
||||
assert str(kickoff_path).endswith('latest_kickoff_task_outputs.db')
|
||||
assert isinstance(rag_path, Path)
|
||||
|
||||
def test_custom_storage_paths():
|
||||
"""Test that custom storage paths are respected."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
custom_path = Path(temp_dir) / 'custom.db'
|
||||
|
||||
ltm = LTMSQLiteStorage(storage_path=custom_path)
|
||||
assert ltm.storage_path == custom_path
|
||||
|
||||
kickoff = KickoffTaskOutputsSQLiteStorage(storage_path=custom_path)
|
||||
assert kickoff.storage_path == custom_path
|
||||
|
||||
rag = MockRAGStorage('test', storage_path=custom_path)
|
||||
assert rag.storage_path == custom_path
|
||||
|
||||
def test_directory_creation():
|
||||
"""Test that storage directories are created automatically."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
test_dir = Path(temp_dir) / 'test_storage'
|
||||
storage_path = test_dir / 'test.db'
|
||||
|
||||
assert not test_dir.exists()
|
||||
LTMSQLiteStorage(storage_path=storage_path)
|
||||
assert test_dir.exists()
|
||||
|
||||
def test_permission_error():
|
||||
"""Test that permission errors are handled correctly."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
test_dir = Path(temp_dir) / 'readonly'
|
||||
test_dir.mkdir()
|
||||
os.chmod(test_dir, 0o444) # Read-only
|
||||
|
||||
storage_path = test_dir / 'test.db'
|
||||
with pytest.raises((PermissionError, OSError)) as exc_info:
|
||||
LTMSQLiteStorage(storage_path=storage_path)
|
||||
# Verify that the error message mentions permission
|
||||
assert "permission" in str(exc_info.value).lower()
|
||||
|
||||
def test_invalid_path():
|
||||
"""Test that invalid paths raise appropriate errors."""
|
||||
with pytest.raises(OSError):
|
||||
# Try to create storage in a non-existent root directory
|
||||
LTMSQLiteStorage(storage_path=Path('/nonexistent/dir/test.db'))
|
||||
68
uv.lock
generated
68
uv.lock
generated
@@ -1,18 +1,10 @@
|
||||
version = 1
|
||||
requires-python = ">=3.10, <3.13"
|
||||
resolution-markers = [
|
||||
"python_full_version < '3.11' and sys_platform == 'darwin'",
|
||||
"python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
|
||||
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
"python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'",
|
||||
"python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
"python_full_version >= '3.12.4' and sys_platform == 'darwin'",
|
||||
"python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12' and python_full_version < '3.12.4'",
|
||||
"python_full_version >= '3.12.4'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -308,7 +300,7 @@ name = "build"
|
||||
version = "1.2.2.post1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "(os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "colorama", marker = "os_name == 'nt'" },
|
||||
{ name = "importlib-metadata", marker = "python_full_version < '3.10.2'" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pyproject-hooks" },
|
||||
@@ -543,7 +535,7 @@ name = "click"
|
||||
version = "8.1.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
|
||||
wheels = [
|
||||
@@ -650,6 +642,7 @@ tools = [
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "cairosvg" },
|
||||
{ name = "crewai-tools" },
|
||||
{ name = "mkdocs" },
|
||||
{ name = "mkdocs-material" },
|
||||
{ name = "mkdocs-material-extensions" },
|
||||
@@ -703,6 +696,7 @@ requires-dist = [
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "cairosvg", specifier = ">=2.7.1" },
|
||||
{ name = "crewai-tools", specifier = ">=0.17.0" },
|
||||
{ name = "mkdocs", specifier = ">=1.4.3" },
|
||||
{ name = "mkdocs-material", specifier = ">=9.5.7" },
|
||||
{ name = "mkdocs-material-extensions", specifier = ">=1.3.1" },
|
||||
@@ -2468,7 +2462,7 @@ version = "1.6.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
{ name = "ghp-import" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "markdown" },
|
||||
@@ -2649,7 +2643,7 @@ version = "2.10.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pygments" },
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
{ name = "pywin32", marker = "platform_system == 'Windows'" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3a/93/80ac75c20ce54c785648b4ed363c88f148bf22637e10c9863db4fbe73e74/mpire-2.10.2.tar.gz", hash = "sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97", size = 271270 }
|
||||
@@ -2896,7 +2890,7 @@ name = "nvidia-cudnn-cu12"
|
||||
version = "9.1.0.70"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
|
||||
@@ -2923,9 +2917,9 @@ name = "nvidia-cusolver-cu12"
|
||||
version = "11.4.5.107"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
|
||||
@@ -2936,7 +2930,7 @@ name = "nvidia-cusparse-cu12"
|
||||
version = "12.1.0.106"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
|
||||
@@ -3486,7 +3480,7 @@ name = "portalocker"
|
||||
version = "2.10.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
{ name = "pywin32", marker = "platform_system == 'Windows'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
|
||||
wheels = [
|
||||
@@ -5028,19 +5022,19 @@ dependencies = [
|
||||
{ name = "fsspec" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "networkx" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "sympy" },
|
||||
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
wheels = [
|
||||
@@ -5087,7 +5081,7 @@ name = "tqdm"
|
||||
version = "4.66.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 }
|
||||
wheels = [
|
||||
@@ -5130,7 +5124,7 @@ version = "0.27.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "attrs" },
|
||||
{ name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "cffi", marker = "implementation_name != 'pypy' and os_name == 'nt'" },
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||
{ name = "idna" },
|
||||
{ name = "outcome" },
|
||||
@@ -5161,7 +5155,7 @@ name = "triton"
|
||||
version = "3.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },
|
||||
|
||||
Reference in New Issue
Block a user