Compare commits

...

3 Commits

Author SHA1 Message Date
Lorenze Jay
92505685e1 Merge branch 'main' into lorenze/imp-pydantic 2026-01-27 15:00:35 -08:00
lorenzejay
ae37e88f53 fix missing import 2026-01-27 13:26:23 -08:00
lorenzejay
02f6926aa0 refactor: update event type definitions to use Literal for type safety
- Changed event type definitions across multiple event classes to use Literal for improved type safety and clarity.
- Updated the  definition in  to utilize Annotated for better schema representation.
- Ensured consistency in type definitions for various events, enhancing the robustness of event handling in the CrewAI framework.
2026-01-27 13:23:26 -08:00
24 changed files with 153 additions and 147 deletions

View File

@@ -1,3 +1,7 @@
from typing import Annotated
from pydantic import Field
from crewai.events.types.a2a_events import ( from crewai.events.types.a2a_events import (
A2AAgentCardFetchedEvent, A2AAgentCardFetchedEvent,
A2AArtifactReceivedEvent, A2AArtifactReceivedEvent,
@@ -102,7 +106,7 @@ from crewai.events.types.tool_usage_events import (
) )
EventTypes = ( EventTypes = Annotated[
A2AAgentCardFetchedEvent A2AAgentCardFetchedEvent
| A2AArtifactReceivedEvent | A2AArtifactReceivedEvent
| A2AAuthenticationFailedEvent | A2AAuthenticationFailedEvent
@@ -180,5 +184,6 @@ EventTypes = (
| MCPConnectionFailedEvent | MCPConnectionFailedEvent
| MCPToolExecutionStartedEvent | MCPToolExecutionStartedEvent
| MCPToolExecutionCompletedEvent | MCPToolExecutionCompletedEvent
| MCPToolExecutionFailedEvent | MCPToolExecutionFailedEvent,
) Field(discriminator="type"),
]

View File

@@ -73,7 +73,7 @@ class A2ADelegationStartedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use. extensions: List of A2A extension URIs in use.
""" """
type: str = "a2a_delegation_started" type: Literal["a2a_delegation_started"] = "a2a_delegation_started"
endpoint: str endpoint: str
task_description: str task_description: str
agent_id: str agent_id: str
@@ -106,7 +106,7 @@ class A2ADelegationCompletedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use. extensions: List of A2A extension URIs in use.
""" """
type: str = "a2a_delegation_completed" type: Literal["a2a_delegation_completed"] = "a2a_delegation_completed"
status: str status: str
result: str | None = None result: str | None = None
error: str | None = None error: str | None = None
@@ -140,7 +140,7 @@ class A2AConversationStartedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use. extensions: List of A2A extension URIs in use.
""" """
type: str = "a2a_conversation_started" type: Literal["a2a_conversation_started"] = "a2a_conversation_started"
agent_id: str agent_id: str
endpoint: str endpoint: str
context_id: str | None = None context_id: str | None = None
@@ -171,7 +171,7 @@ class A2AMessageSentEvent(A2AEventBase):
extensions: List of A2A extension URIs in use. extensions: List of A2A extension URIs in use.
""" """
type: str = "a2a_message_sent" type: Literal["a2a_message_sent"] = "a2a_message_sent"
message: str message: str
turn_number: int turn_number: int
context_id: str | None = None context_id: str | None = None
@@ -203,7 +203,7 @@ class A2AResponseReceivedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use. extensions: List of A2A extension URIs in use.
""" """
type: str = "a2a_response_received" type: Literal["a2a_response_received"] = "a2a_response_received"
response: str response: str
turn_number: int turn_number: int
context_id: str | None = None context_id: str | None = None
@@ -237,7 +237,7 @@ class A2AConversationCompletedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use. extensions: List of A2A extension URIs in use.
""" """
type: str = "a2a_conversation_completed" type: Literal["a2a_conversation_completed"] = "a2a_conversation_completed"
status: Literal["completed", "failed"] status: Literal["completed", "failed"]
final_result: str | None = None final_result: str | None = None
error: str | None = None error: str | None = None
@@ -263,7 +263,7 @@ class A2APollingStartedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_polling_started" type: Literal["a2a_polling_started"] = "a2a_polling_started"
task_id: str task_id: str
context_id: str | None = None context_id: str | None = None
polling_interval: float polling_interval: float
@@ -286,7 +286,7 @@ class A2APollingStatusEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_polling_status" type: Literal["a2a_polling_status"] = "a2a_polling_status"
task_id: str task_id: str
context_id: str | None = None context_id: str | None = None
state: str state: str
@@ -309,7 +309,7 @@ class A2APushNotificationRegisteredEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_push_notification_registered" type: Literal["a2a_push_notification_registered"] = "a2a_push_notification_registered"
task_id: str task_id: str
context_id: str | None = None context_id: str | None = None
callback_url: str callback_url: str
@@ -334,7 +334,7 @@ class A2APushNotificationReceivedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_push_notification_received" type: Literal["a2a_push_notification_received"] = "a2a_push_notification_received"
task_id: str task_id: str
context_id: str | None = None context_id: str | None = None
state: str state: str
@@ -359,7 +359,7 @@ class A2APushNotificationSentEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_push_notification_sent" type: Literal["a2a_push_notification_sent"] = "a2a_push_notification_sent"
task_id: str task_id: str
context_id: str | None = None context_id: str | None = None
callback_url: str callback_url: str
@@ -381,7 +381,7 @@ class A2APushNotificationTimeoutEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_push_notification_timeout" type: Literal["a2a_push_notification_timeout"] = "a2a_push_notification_timeout"
task_id: str task_id: str
context_id: str | None = None context_id: str | None = None
timeout_seconds: float timeout_seconds: float
@@ -405,7 +405,7 @@ class A2AStreamingStartedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use. extensions: List of A2A extension URIs in use.
""" """
type: str = "a2a_streaming_started" type: Literal["a2a_streaming_started"] = "a2a_streaming_started"
task_id: str | None = None task_id: str | None = None
context_id: str | None = None context_id: str | None = None
endpoint: str endpoint: str
@@ -434,7 +434,7 @@ class A2AStreamingChunkEvent(A2AEventBase):
extensions: List of A2A extension URIs in use. extensions: List of A2A extension URIs in use.
""" """
type: str = "a2a_streaming_chunk" type: Literal["a2a_streaming_chunk"] = "a2a_streaming_chunk"
task_id: str | None = None task_id: str | None = None
context_id: str | None = None context_id: str | None = None
chunk: str chunk: str
@@ -462,7 +462,7 @@ class A2AAgentCardFetchedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_agent_card_fetched" type: Literal["a2a_agent_card_fetched"] = "a2a_agent_card_fetched"
endpoint: str endpoint: str
a2a_agent_name: str | None = None a2a_agent_name: str | None = None
agent_card: dict[str, Any] | None = None agent_card: dict[str, Any] | None = None
@@ -486,7 +486,7 @@ class A2AAuthenticationFailedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_authentication_failed" type: Literal["a2a_authentication_failed"] = "a2a_authentication_failed"
endpoint: str endpoint: str
auth_type: str | None = None auth_type: str | None = None
error: str error: str
@@ -517,7 +517,7 @@ class A2AArtifactReceivedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use. extensions: List of A2A extension URIs in use.
""" """
type: str = "a2a_artifact_received" type: Literal["a2a_artifact_received"] = "a2a_artifact_received"
task_id: str task_id: str
artifact_id: str artifact_id: str
artifact_name: str | None = None artifact_name: str | None = None
@@ -550,7 +550,7 @@ class A2AConnectionErrorEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_connection_error" type: Literal["a2a_connection_error"] = "a2a_connection_error"
endpoint: str endpoint: str
error: str error: str
error_type: str | None = None error_type: str | None = None
@@ -571,7 +571,7 @@ class A2AServerTaskStartedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_server_task_started" type: Literal["a2a_server_task_started"] = "a2a_server_task_started"
task_id: str task_id: str
context_id: str context_id: str
metadata: dict[str, Any] | None = None metadata: dict[str, Any] | None = None
@@ -587,7 +587,7 @@ class A2AServerTaskCompletedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_server_task_completed" type: Literal["a2a_server_task_completed"] = "a2a_server_task_completed"
task_id: str task_id: str
context_id: str context_id: str
result: str result: str
@@ -603,7 +603,7 @@ class A2AServerTaskCanceledEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_server_task_canceled" type: Literal["a2a_server_task_canceled"] = "a2a_server_task_canceled"
task_id: str task_id: str
context_id: str context_id: str
metadata: dict[str, Any] | None = None metadata: dict[str, Any] | None = None
@@ -619,7 +619,7 @@ class A2AServerTaskFailedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs. metadata: Custom A2A metadata key-value pairs.
""" """
type: str = "a2a_server_task_failed" type: Literal["a2a_server_task_failed"] = "a2a_server_task_failed"
task_id: str task_id: str
context_id: str context_id: str
error: str error: str
@@ -634,7 +634,7 @@ class A2AParallelDelegationStartedEvent(A2AEventBase):
task_description: Description of the task being delegated. task_description: Description of the task being delegated.
""" """
type: str = "a2a_parallel_delegation_started" type: Literal["a2a_parallel_delegation_started"] = "a2a_parallel_delegation_started"
endpoints: list[str] endpoints: list[str]
task_description: str task_description: str
@@ -649,7 +649,7 @@ class A2AParallelDelegationCompletedEvent(A2AEventBase):
results: Summary of results from each agent. results: Summary of results from each agent.
""" """
type: str = "a2a_parallel_delegation_completed" type: Literal["a2a_parallel_delegation_completed"] = "a2a_parallel_delegation_completed"
endpoints: list[str] endpoints: list[str]
success_count: int success_count: int
failure_count: int failure_count: int

View File

@@ -2,8 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from typing import Any, Literal
from typing import Any
from pydantic import ConfigDict, model_validator from pydantic import ConfigDict, model_validator
@@ -18,9 +17,9 @@ class AgentExecutionStartedEvent(BaseEvent):
agent: BaseAgent agent: BaseAgent
task: Any task: Any
tools: Sequence[BaseTool | CrewStructuredTool] | None tools: list[BaseTool | CrewStructuredTool] | None
task_prompt: str task_prompt: str
type: str = "agent_execution_started" type: Literal["agent_execution_started"] = "agent_execution_started"
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -44,7 +43,7 @@ class AgentExecutionCompletedEvent(BaseEvent):
agent: BaseAgent agent: BaseAgent
task: Any task: Any
output: str output: str
type: str = "agent_execution_completed" type: Literal["agent_execution_completed"] = "agent_execution_completed"
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -68,7 +67,7 @@ class AgentExecutionErrorEvent(BaseEvent):
agent: BaseAgent agent: BaseAgent
task: Any task: Any
error: str error: str
type: str = "agent_execution_error" type: Literal["agent_execution_error"] = "agent_execution_error"
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -91,9 +90,9 @@ class LiteAgentExecutionStartedEvent(BaseEvent):
"""Event emitted when a LiteAgent starts executing""" """Event emitted when a LiteAgent starts executing"""
agent_info: dict[str, Any] agent_info: dict[str, Any]
tools: Sequence[BaseTool | CrewStructuredTool] | None tools: list[BaseTool | CrewStructuredTool] | None
messages: str | list[dict[str, str]] messages: str | list[dict[str, str]]
type: str = "lite_agent_execution_started" type: Literal["lite_agent_execution_started"] = "lite_agent_execution_started"
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -103,7 +102,7 @@ class LiteAgentExecutionCompletedEvent(BaseEvent):
agent_info: dict[str, Any] agent_info: dict[str, Any]
output: str output: str
type: str = "lite_agent_execution_completed" type: Literal["lite_agent_execution_completed"] = "lite_agent_execution_completed"
class LiteAgentExecutionErrorEvent(BaseEvent): class LiteAgentExecutionErrorEvent(BaseEvent):
@@ -111,7 +110,7 @@ class LiteAgentExecutionErrorEvent(BaseEvent):
agent_info: dict[str, Any] agent_info: dict[str, Any]
error: str error: str
type: str = "lite_agent_execution_error" type: Literal["lite_agent_execution_error"] = "lite_agent_execution_error"
# Agent Eval events # Agent Eval events
@@ -120,7 +119,7 @@ class AgentEvaluationStartedEvent(BaseEvent):
agent_role: str agent_role: str
task_id: str | None = None task_id: str | None = None
iteration: int iteration: int
type: str = "agent_evaluation_started" type: Literal["agent_evaluation_started"] = "agent_evaluation_started"
class AgentEvaluationCompletedEvent(BaseEvent): class AgentEvaluationCompletedEvent(BaseEvent):
@@ -130,7 +129,7 @@ class AgentEvaluationCompletedEvent(BaseEvent):
iteration: int iteration: int
metric_category: Any metric_category: Any
score: Any score: Any
type: str = "agent_evaluation_completed" type: Literal["agent_evaluation_completed"] = "agent_evaluation_completed"
class AgentEvaluationFailedEvent(BaseEvent): class AgentEvaluationFailedEvent(BaseEvent):
@@ -139,4 +138,4 @@ class AgentEvaluationFailedEvent(BaseEvent):
task_id: str | None = None task_id: str | None = None
iteration: int iteration: int
error: str error: str
type: str = "agent_evaluation_failed" type: Literal["agent_evaluation_failed"] = "agent_evaluation_failed"

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, Literal
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
@@ -40,14 +40,14 @@ class CrewKickoffStartedEvent(CrewBaseEvent):
"""Event emitted when a crew starts execution""" """Event emitted when a crew starts execution"""
inputs: dict[str, Any] | None inputs: dict[str, Any] | None
type: str = "crew_kickoff_started" type: Literal["crew_kickoff_started"] = "crew_kickoff_started"
class CrewKickoffCompletedEvent(CrewBaseEvent): class CrewKickoffCompletedEvent(CrewBaseEvent):
"""Event emitted when a crew completes execution""" """Event emitted when a crew completes execution"""
output: Any output: Any
type: str = "crew_kickoff_completed" type: Literal["crew_kickoff_completed"] = "crew_kickoff_completed"
total_tokens: int = 0 total_tokens: int = 0
@@ -55,7 +55,7 @@ class CrewKickoffFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete execution""" """Event emitted when a crew fails to complete execution"""
error: str error: str
type: str = "crew_kickoff_failed" type: Literal["crew_kickoff_failed"] = "crew_kickoff_failed"
class CrewTrainStartedEvent(CrewBaseEvent): class CrewTrainStartedEvent(CrewBaseEvent):
@@ -64,7 +64,7 @@ class CrewTrainStartedEvent(CrewBaseEvent):
n_iterations: int n_iterations: int
filename: str filename: str
inputs: dict[str, Any] | None inputs: dict[str, Any] | None
type: str = "crew_train_started" type: Literal["crew_train_started"] = "crew_train_started"
class CrewTrainCompletedEvent(CrewBaseEvent): class CrewTrainCompletedEvent(CrewBaseEvent):
@@ -72,14 +72,14 @@ class CrewTrainCompletedEvent(CrewBaseEvent):
n_iterations: int n_iterations: int
filename: str filename: str
type: str = "crew_train_completed" type: Literal["crew_train_completed"] = "crew_train_completed"
class CrewTrainFailedEvent(CrewBaseEvent): class CrewTrainFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete training""" """Event emitted when a crew fails to complete training"""
error: str error: str
type: str = "crew_train_failed" type: Literal["crew_train_failed"] = "crew_train_failed"
class CrewTestStartedEvent(CrewBaseEvent): class CrewTestStartedEvent(CrewBaseEvent):
@@ -88,20 +88,20 @@ class CrewTestStartedEvent(CrewBaseEvent):
n_iterations: int n_iterations: int
eval_llm: str | Any | None eval_llm: str | Any | None
inputs: dict[str, Any] | None inputs: dict[str, Any] | None
type: str = "crew_test_started" type: Literal["crew_test_started"] = "crew_test_started"
class CrewTestCompletedEvent(CrewBaseEvent): class CrewTestCompletedEvent(CrewBaseEvent):
"""Event emitted when a crew completes testing""" """Event emitted when a crew completes testing"""
type: str = "crew_test_completed" type: Literal["crew_test_completed"] = "crew_test_completed"
class CrewTestFailedEvent(CrewBaseEvent): class CrewTestFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete testing""" """Event emitted when a crew fails to complete testing"""
error: str error: str
type: str = "crew_test_failed" type: Literal["crew_test_failed"] = "crew_test_failed"
class CrewTestResultEvent(CrewBaseEvent): class CrewTestResultEvent(CrewBaseEvent):
@@ -110,4 +110,4 @@ class CrewTestResultEvent(CrewBaseEvent):
quality: float quality: float
execution_duration: float execution_duration: float
model: str model: str
type: str = "crew_test_result" type: Literal["crew_test_result"] = "crew_test_result"

View File

@@ -1,4 +1,4 @@
from typing import Any from typing import Any, Literal
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@@ -17,14 +17,14 @@ class FlowStartedEvent(FlowEvent):
flow_name: str flow_name: str
inputs: dict[str, Any] | None = None inputs: dict[str, Any] | None = None
type: str = "flow_started" type: Literal["flow_started"] = "flow_started"
class FlowCreatedEvent(FlowEvent): class FlowCreatedEvent(FlowEvent):
"""Event emitted when a flow is created""" """Event emitted when a flow is created"""
flow_name: str flow_name: str
type: str = "flow_created" type: Literal["flow_created"] = "flow_created"
class MethodExecutionStartedEvent(FlowEvent): class MethodExecutionStartedEvent(FlowEvent):
@@ -34,7 +34,7 @@ class MethodExecutionStartedEvent(FlowEvent):
method_name: str method_name: str
state: dict[str, Any] | BaseModel state: dict[str, Any] | BaseModel
params: dict[str, Any] | None = None params: dict[str, Any] | None = None
type: str = "method_execution_started" type: Literal["method_execution_started"] = "method_execution_started"
class MethodExecutionFinishedEvent(FlowEvent): class MethodExecutionFinishedEvent(FlowEvent):
@@ -44,7 +44,7 @@ class MethodExecutionFinishedEvent(FlowEvent):
method_name: str method_name: str
result: Any = None result: Any = None
state: dict[str, Any] | BaseModel state: dict[str, Any] | BaseModel
type: str = "method_execution_finished" type: Literal["method_execution_finished"] = "method_execution_finished"
class MethodExecutionFailedEvent(FlowEvent): class MethodExecutionFailedEvent(FlowEvent):
@@ -53,7 +53,7 @@ class MethodExecutionFailedEvent(FlowEvent):
flow_name: str flow_name: str
method_name: str method_name: str
error: Exception error: Exception
type: str = "method_execution_failed" type: Literal["method_execution_failed"] = "method_execution_failed"
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -78,7 +78,7 @@ class MethodExecutionPausedEvent(FlowEvent):
flow_id: str flow_id: str
message: str message: str
emit: list[str] | None = None emit: list[str] | None = None
type: str = "method_execution_paused" type: Literal["method_execution_paused"] = "method_execution_paused"
class FlowFinishedEvent(FlowEvent): class FlowFinishedEvent(FlowEvent):
@@ -86,7 +86,7 @@ class FlowFinishedEvent(FlowEvent):
flow_name: str flow_name: str
result: Any | None = None result: Any | None = None
type: str = "flow_finished" type: Literal["flow_finished"] = "flow_finished"
state: dict[str, Any] | BaseModel state: dict[str, Any] | BaseModel
@@ -110,14 +110,14 @@ class FlowPausedEvent(FlowEvent):
state: dict[str, Any] | BaseModel state: dict[str, Any] | BaseModel
message: str message: str
emit: list[str] | None = None emit: list[str] | None = None
type: str = "flow_paused" type: Literal["flow_paused"] = "flow_paused"
class FlowPlotEvent(FlowEvent): class FlowPlotEvent(FlowEvent):
"""Event emitted when a flow plot is created""" """Event emitted when a flow plot is created"""
flow_name: str flow_name: str
type: str = "flow_plot" type: Literal["flow_plot"] = "flow_plot"
class HumanFeedbackRequestedEvent(FlowEvent): class HumanFeedbackRequestedEvent(FlowEvent):
@@ -138,7 +138,7 @@ class HumanFeedbackRequestedEvent(FlowEvent):
output: Any output: Any
message: str message: str
emit: list[str] | None = None emit: list[str] | None = None
type: str = "human_feedback_requested" type: Literal["human_feedback_requested"] = "human_feedback_requested"
class HumanFeedbackReceivedEvent(FlowEvent): class HumanFeedbackReceivedEvent(FlowEvent):
@@ -157,4 +157,4 @@ class HumanFeedbackReceivedEvent(FlowEvent):
method_name: str method_name: str
feedback: str feedback: str
outcome: str | None = None outcome: str | None = None
type: str = "human_feedback_received" type: Literal["human_feedback_received"] = "human_feedback_received"

View File

@@ -1,4 +1,4 @@
from typing import Any from typing import Any, Literal
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
@@ -20,14 +20,14 @@ class KnowledgeEventBase(BaseEvent):
class KnowledgeRetrievalStartedEvent(KnowledgeEventBase): class KnowledgeRetrievalStartedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge retrieval is started.""" """Event emitted when a knowledge retrieval is started."""
type: str = "knowledge_search_query_started" type: Literal["knowledge_search_query_started"] = "knowledge_search_query_started"
class KnowledgeRetrievalCompletedEvent(KnowledgeEventBase): class KnowledgeRetrievalCompletedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge retrieval is completed.""" """Event emitted when a knowledge retrieval is completed."""
query: str query: str
type: str = "knowledge_search_query_completed" type: Literal["knowledge_search_query_completed"] = "knowledge_search_query_completed"
retrieved_knowledge: str retrieved_knowledge: str
@@ -35,13 +35,13 @@ class KnowledgeQueryStartedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query is started.""" """Event emitted when a knowledge query is started."""
task_prompt: str task_prompt: str
type: str = "knowledge_query_started" type: Literal["knowledge_query_started"] = "knowledge_query_started"
class KnowledgeQueryFailedEvent(KnowledgeEventBase): class KnowledgeQueryFailedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query fails.""" """Event emitted when a knowledge query fails."""
type: str = "knowledge_query_failed" type: Literal["knowledge_query_failed"] = "knowledge_query_failed"
error: str error: str
@@ -49,12 +49,12 @@ class KnowledgeQueryCompletedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query is completed.""" """Event emitted when a knowledge query is completed."""
query: str query: str
type: str = "knowledge_query_completed" type: Literal["knowledge_query_completed"] = "knowledge_query_completed"
class KnowledgeSearchQueryFailedEvent(KnowledgeEventBase): class KnowledgeSearchQueryFailedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge search query fails.""" """Event emitted when a knowledge search query fails."""
query: str query: str
type: str = "knowledge_search_query_failed" type: Literal["knowledge_search_query_failed"] = "knowledge_search_query_failed"
error: str error: str

View File

@@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Any from typing import Any, Literal
from pydantic import BaseModel from pydantic import BaseModel
@@ -42,7 +42,7 @@ class LLMCallStartedEvent(LLMEventBase):
multimodal content (text, images, etc.) multimodal content (text, images, etc.)
""" """
type: str = "llm_call_started" type: Literal["llm_call_started"] = "llm_call_started"
messages: str | list[dict[str, Any]] | None = None messages: str | list[dict[str, Any]] | None = None
tools: list[dict[str, Any]] | None = None tools: list[dict[str, Any]] | None = None
callbacks: list[Any] | None = None callbacks: list[Any] | None = None
@@ -52,7 +52,7 @@ class LLMCallStartedEvent(LLMEventBase):
class LLMCallCompletedEvent(LLMEventBase): class LLMCallCompletedEvent(LLMEventBase):
"""Event emitted when a LLM call completes""" """Event emitted when a LLM call completes"""
type: str = "llm_call_completed" type: Literal["llm_call_completed"] = "llm_call_completed"
messages: str | list[dict[str, Any]] | None = None messages: str | list[dict[str, Any]] | None = None
response: Any response: Any
call_type: LLMCallType call_type: LLMCallType
@@ -62,7 +62,7 @@ class LLMCallFailedEvent(LLMEventBase):
"""Event emitted when a LLM call fails""" """Event emitted when a LLM call fails"""
error: str error: str
type: str = "llm_call_failed" type: Literal["llm_call_failed"] = "llm_call_failed"
class FunctionCall(BaseModel): class FunctionCall(BaseModel):
@@ -80,7 +80,7 @@ class ToolCall(BaseModel):
class LLMStreamChunkEvent(LLMEventBase): class LLMStreamChunkEvent(LLMEventBase):
"""Event emitted when a streaming chunk is received""" """Event emitted when a streaming chunk is received"""
type: str = "llm_stream_chunk" type: Literal["llm_stream_chunk"] = "llm_stream_chunk"
chunk: str chunk: str
tool_call: ToolCall | None = None tool_call: ToolCall | None = None
call_type: LLMCallType | None = None call_type: LLMCallType | None = None

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable from collections.abc import Callable
from inspect import getsource from inspect import getsource
from typing import Any from typing import Any, Literal
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
@@ -27,7 +27,7 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
retry_count: The number of times the guardrail has been retried retry_count: The number of times the guardrail has been retried
""" """
type: str = "llm_guardrail_started" type: Literal["llm_guardrail_started"] = "llm_guardrail_started"
guardrail: str | Callable guardrail: str | Callable
retry_count: int retry_count: int
@@ -53,7 +53,7 @@ class LLMGuardrailCompletedEvent(LLMGuardrailBaseEvent):
retry_count: The number of times the guardrail has been retried retry_count: The number of times the guardrail has been retried
""" """
type: str = "llm_guardrail_completed" type: Literal["llm_guardrail_completed"] = "llm_guardrail_completed"
success: bool success: bool
result: Any result: Any
error: str | None = None error: str | None = None
@@ -68,6 +68,6 @@ class LLMGuardrailFailedEvent(LLMGuardrailBaseEvent):
retry_count: The number of times the guardrail has been retried retry_count: The number of times the guardrail has been retried
""" """
type: str = "llm_guardrail_failed" type: Literal["llm_guardrail_failed"] = "llm_guardrail_failed"
error: str error: str
retry_count: int retry_count: int

View File

@@ -1,6 +1,6 @@
"""Agent logging events that don't reference BaseAgent to avoid circular imports.""" """Agent logging events that don't reference BaseAgent to avoid circular imports."""
from typing import Any from typing import Any, Literal
from pydantic import ConfigDict from pydantic import ConfigDict
@@ -13,7 +13,7 @@ class AgentLogsStartedEvent(BaseEvent):
agent_role: str agent_role: str
task_description: str | None = None task_description: str | None = None
verbose: bool = False verbose: bool = False
type: str = "agent_logs_started" type: Literal["agent_logs_started"] = "agent_logs_started"
class AgentLogsExecutionEvent(BaseEvent): class AgentLogsExecutionEvent(BaseEvent):
@@ -22,6 +22,6 @@ class AgentLogsExecutionEvent(BaseEvent):
agent_role: str agent_role: str
formatted_answer: Any formatted_answer: Any
verbose: bool = False verbose: bool = False
type: str = "agent_logs_execution" type: Literal["agent_logs_execution"] = "agent_logs_execution"
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any, Literal
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
@@ -24,7 +24,7 @@ class MCPEvent(BaseEvent):
class MCPConnectionStartedEvent(MCPEvent): class MCPConnectionStartedEvent(MCPEvent):
"""Event emitted when starting to connect to an MCP server.""" """Event emitted when starting to connect to an MCP server."""
type: str = "mcp_connection_started" type: Literal["mcp_connection_started"] = "mcp_connection_started"
connect_timeout: int | None = None connect_timeout: int | None = None
is_reconnect: bool = ( is_reconnect: bool = (
False # True if this is a reconnection, False for first connection False # True if this is a reconnection, False for first connection
@@ -34,7 +34,7 @@ class MCPConnectionStartedEvent(MCPEvent):
class MCPConnectionCompletedEvent(MCPEvent): class MCPConnectionCompletedEvent(MCPEvent):
"""Event emitted when successfully connected to an MCP server.""" """Event emitted when successfully connected to an MCP server."""
type: str = "mcp_connection_completed" type: Literal["mcp_connection_completed"] = "mcp_connection_completed"
started_at: datetime | None = None started_at: datetime | None = None
completed_at: datetime | None = None completed_at: datetime | None = None
connection_duration_ms: float | None = None connection_duration_ms: float | None = None
@@ -46,7 +46,7 @@ class MCPConnectionCompletedEvent(MCPEvent):
class MCPConnectionFailedEvent(MCPEvent): class MCPConnectionFailedEvent(MCPEvent):
"""Event emitted when connection to an MCP server fails.""" """Event emitted when connection to an MCP server fails."""
type: str = "mcp_connection_failed" type: Literal["mcp_connection_failed"] = "mcp_connection_failed"
error: str error: str
error_type: str | None = None # "timeout", "authentication", "network", etc. error_type: str | None = None # "timeout", "authentication", "network", etc.
started_at: datetime | None = None started_at: datetime | None = None
@@ -56,7 +56,7 @@ class MCPConnectionFailedEvent(MCPEvent):
class MCPToolExecutionStartedEvent(MCPEvent): class MCPToolExecutionStartedEvent(MCPEvent):
"""Event emitted when starting to execute an MCP tool.""" """Event emitted when starting to execute an MCP tool."""
type: str = "mcp_tool_execution_started" type: Literal["mcp_tool_execution_started"] = "mcp_tool_execution_started"
tool_name: str tool_name: str
tool_args: dict[str, Any] | None = None tool_args: dict[str, Any] | None = None
@@ -64,7 +64,7 @@ class MCPToolExecutionStartedEvent(MCPEvent):
class MCPToolExecutionCompletedEvent(MCPEvent): class MCPToolExecutionCompletedEvent(MCPEvent):
"""Event emitted when MCP tool execution completes.""" """Event emitted when MCP tool execution completes."""
type: str = "mcp_tool_execution_completed" type: Literal["mcp_tool_execution_completed"] = "mcp_tool_execution_completed"
tool_name: str tool_name: str
tool_args: dict[str, Any] | None = None tool_args: dict[str, Any] | None = None
result: Any | None = None result: Any | None = None
@@ -76,7 +76,7 @@ class MCPToolExecutionCompletedEvent(MCPEvent):
class MCPToolExecutionFailedEvent(MCPEvent): class MCPToolExecutionFailedEvent(MCPEvent):
"""Event emitted when MCP tool execution fails.""" """Event emitted when MCP tool execution fails."""
type: str = "mcp_tool_execution_failed" type: Literal["mcp_tool_execution_failed"] = "mcp_tool_execution_failed"
tool_name: str tool_name: str
tool_args: dict[str, Any] | None = None tool_args: dict[str, Any] | None = None
error: str error: str

View File

@@ -1,4 +1,4 @@
from typing import Any from typing import Any, Literal
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
@@ -23,7 +23,7 @@ class MemoryBaseEvent(BaseEvent):
class MemoryQueryStartedEvent(MemoryBaseEvent): class MemoryQueryStartedEvent(MemoryBaseEvent):
"""Event emitted when a memory query is started""" """Event emitted when a memory query is started"""
type: str = "memory_query_started" type: Literal["memory_query_started"] = "memory_query_started"
query: str query: str
limit: int limit: int
score_threshold: float | None = None score_threshold: float | None = None
@@ -32,7 +32,7 @@ class MemoryQueryStartedEvent(MemoryBaseEvent):
class MemoryQueryCompletedEvent(MemoryBaseEvent): class MemoryQueryCompletedEvent(MemoryBaseEvent):
"""Event emitted when a memory query is completed successfully""" """Event emitted when a memory query is completed successfully"""
type: str = "memory_query_completed" type: Literal["memory_query_completed"] = "memory_query_completed"
query: str query: str
results: Any results: Any
limit: int limit: int
@@ -43,7 +43,7 @@ class MemoryQueryCompletedEvent(MemoryBaseEvent):
class MemoryQueryFailedEvent(MemoryBaseEvent): class MemoryQueryFailedEvent(MemoryBaseEvent):
"""Event emitted when a memory query fails""" """Event emitted when a memory query fails"""
type: str = "memory_query_failed" type: Literal["memory_query_failed"] = "memory_query_failed"
query: str query: str
limit: int limit: int
score_threshold: float | None = None score_threshold: float | None = None
@@ -53,7 +53,7 @@ class MemoryQueryFailedEvent(MemoryBaseEvent):
class MemorySaveStartedEvent(MemoryBaseEvent): class MemorySaveStartedEvent(MemoryBaseEvent):
"""Event emitted when a memory save operation is started""" """Event emitted when a memory save operation is started"""
type: str = "memory_save_started" type: Literal["memory_save_started"] = "memory_save_started"
value: str | None = None value: str | None = None
metadata: dict[str, Any] | None = None metadata: dict[str, Any] | None = None
agent_role: str | None = None agent_role: str | None = None
@@ -62,7 +62,7 @@ class MemorySaveStartedEvent(MemoryBaseEvent):
class MemorySaveCompletedEvent(MemoryBaseEvent): class MemorySaveCompletedEvent(MemoryBaseEvent):
"""Event emitted when a memory save operation is completed successfully""" """Event emitted when a memory save operation is completed successfully"""
type: str = "memory_save_completed" type: Literal["memory_save_completed"] = "memory_save_completed"
value: str value: str
metadata: dict[str, Any] | None = None metadata: dict[str, Any] | None = None
agent_role: str | None = None agent_role: str | None = None
@@ -72,7 +72,7 @@ class MemorySaveCompletedEvent(MemoryBaseEvent):
class MemorySaveFailedEvent(MemoryBaseEvent): class MemorySaveFailedEvent(MemoryBaseEvent):
"""Event emitted when a memory save operation fails""" """Event emitted when a memory save operation fails"""
type: str = "memory_save_failed" type: Literal["memory_save_failed"] = "memory_save_failed"
value: str | None = None value: str | None = None
metadata: dict[str, Any] | None = None metadata: dict[str, Any] | None = None
agent_role: str | None = None agent_role: str | None = None
@@ -82,14 +82,14 @@ class MemorySaveFailedEvent(MemoryBaseEvent):
class MemoryRetrievalStartedEvent(MemoryBaseEvent): class MemoryRetrievalStartedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt starts""" """Event emitted when memory retrieval for a task prompt starts"""
type: str = "memory_retrieval_started" type: Literal["memory_retrieval_started"] = "memory_retrieval_started"
task_id: str | None = None task_id: str | None = None
class MemoryRetrievalCompletedEvent(MemoryBaseEvent): class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt completes successfully""" """Event emitted when memory retrieval for a task prompt completes successfully"""
type: str = "memory_retrieval_completed" type: Literal["memory_retrieval_completed"] = "memory_retrieval_completed"
task_id: str | None = None task_id: str | None = None
memory_content: str memory_content: str
retrieval_time_ms: float retrieval_time_ms: float
@@ -98,6 +98,6 @@ class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
class MemoryRetrievalFailedEvent(MemoryBaseEvent): class MemoryRetrievalFailedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt fails.""" """Event emitted when memory retrieval for a task prompt fails."""
type: str = "memory_retrieval_failed" type: Literal["memory_retrieval_failed"] = "memory_retrieval_failed"
task_id: str | None = None task_id: str | None = None
error: str error: str

View File

@@ -1,4 +1,4 @@
from typing import Any from typing import Any, Literal
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
@@ -24,7 +24,7 @@ class ReasoningEvent(BaseEvent):
class AgentReasoningStartedEvent(ReasoningEvent): class AgentReasoningStartedEvent(ReasoningEvent):
"""Event emitted when an agent starts reasoning about a task.""" """Event emitted when an agent starts reasoning about a task."""
type: str = "agent_reasoning_started" type: Literal["agent_reasoning_started"] = "agent_reasoning_started"
agent_role: str agent_role: str
task_id: str task_id: str
@@ -32,7 +32,7 @@ class AgentReasoningStartedEvent(ReasoningEvent):
class AgentReasoningCompletedEvent(ReasoningEvent): class AgentReasoningCompletedEvent(ReasoningEvent):
"""Event emitted when an agent finishes its reasoning process.""" """Event emitted when an agent finishes its reasoning process."""
type: str = "agent_reasoning_completed" type: Literal["agent_reasoning_completed"] = "agent_reasoning_completed"
agent_role: str agent_role: str
task_id: str task_id: str
plan: str plan: str
@@ -42,7 +42,7 @@ class AgentReasoningCompletedEvent(ReasoningEvent):
class AgentReasoningFailedEvent(ReasoningEvent): class AgentReasoningFailedEvent(ReasoningEvent):
"""Event emitted when the reasoning process fails.""" """Event emitted when the reasoning process fails."""
type: str = "agent_reasoning_failed" type: Literal["agent_reasoning_failed"] = "agent_reasoning_failed"
agent_role: str agent_role: str
task_id: str task_id: str
error: str error: str

View File

@@ -1,4 +1,4 @@
from typing import Any from typing import Any, Literal
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
from crewai.tasks.task_output import TaskOutput from crewai.tasks.task_output import TaskOutput
@@ -7,7 +7,7 @@ from crewai.tasks.task_output import TaskOutput
class TaskStartedEvent(BaseEvent): class TaskStartedEvent(BaseEvent):
"""Event emitted when a task starts""" """Event emitted when a task starts"""
type: str = "task_started" type: Literal["task_started"] = "task_started"
context: str | None context: str | None
task: Any | None = None task: Any | None = None
@@ -28,7 +28,7 @@ class TaskCompletedEvent(BaseEvent):
"""Event emitted when a task completes""" """Event emitted when a task completes"""
output: TaskOutput output: TaskOutput
type: str = "task_completed" type: Literal["task_completed"] = "task_completed"
task: Any | None = None task: Any | None = None
def __init__(self, **data): def __init__(self, **data):
@@ -48,7 +48,7 @@ class TaskFailedEvent(BaseEvent):
"""Event emitted when a task fails""" """Event emitted when a task fails"""
error: str error: str
type: str = "task_failed" type: Literal["task_failed"] = "task_failed"
task: Any | None = None task: Any | None = None
def __init__(self, **data): def __init__(self, **data):
@@ -67,7 +67,7 @@ class TaskFailedEvent(BaseEvent):
class TaskEvaluationEvent(BaseEvent): class TaskEvaluationEvent(BaseEvent):
"""Event emitted when a task evaluation is completed""" """Event emitted when a task evaluation is completed"""
type: str = "task_evaluation" type: Literal["task_evaluation"] = "task_evaluation"
evaluation_type: str evaluation_type: str
task: Any | None = None task: Any | None = None

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any, Literal
from pydantic import ConfigDict from pydantic import ConfigDict
@@ -55,7 +55,7 @@ class ToolUsageEvent(BaseEvent):
class ToolUsageStartedEvent(ToolUsageEvent): class ToolUsageStartedEvent(ToolUsageEvent):
"""Event emitted when a tool execution is started""" """Event emitted when a tool execution is started"""
type: str = "tool_usage_started" type: Literal["tool_usage_started"] = "tool_usage_started"
class ToolUsageFinishedEvent(ToolUsageEvent): class ToolUsageFinishedEvent(ToolUsageEvent):
@@ -65,35 +65,35 @@ class ToolUsageFinishedEvent(ToolUsageEvent):
finished_at: datetime finished_at: datetime
from_cache: bool = False from_cache: bool = False
output: Any output: Any
type: str = "tool_usage_finished" type: Literal["tool_usage_finished"] = "tool_usage_finished"
class ToolUsageErrorEvent(ToolUsageEvent): class ToolUsageErrorEvent(ToolUsageEvent):
"""Event emitted when a tool execution encounters an error""" """Event emitted when a tool execution encounters an error"""
error: Any error: Any
type: str = "tool_usage_error" type: Literal["tool_usage_error"] = "tool_usage_error"
class ToolValidateInputErrorEvent(ToolUsageEvent): class ToolValidateInputErrorEvent(ToolUsageEvent):
"""Event emitted when a tool input validation encounters an error""" """Event emitted when a tool input validation encounters an error"""
error: Any error: Any
type: str = "tool_validate_input_error" type: Literal["tool_validate_input_error"] = "tool_validate_input_error"
class ToolSelectionErrorEvent(ToolUsageEvent): class ToolSelectionErrorEvent(ToolUsageEvent):
"""Event emitted when a tool selection encounters an error""" """Event emitted when a tool selection encounters an error"""
error: Any error: Any
type: str = "tool_selection_error" type: Literal["tool_selection_error"] = "tool_selection_error"
class ToolExecutionErrorEvent(BaseEvent): class ToolExecutionErrorEvent(BaseEvent):
"""Event emitted when a tool execution encounters an error""" """Event emitted when a tool execution encounters an error"""
error: Any error: Any
type: str = "tool_execution_error" type: Literal["tool_execution_error"] = "tool_execution_error"
tool_name: str tool_name: str
tool_args: dict[str, Any] tool_args: dict[str, Any]
tool_class: Callable tool_class: Callable

View File

@@ -7,7 +7,7 @@ for building event-driven workflows with conditional execution and routing.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Sequence from collections.abc import Callable
from concurrent.futures import Future from concurrent.futures import Future
import copy import copy
import inspect import inspect
@@ -2382,7 +2382,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
message: str, message: str,
output: Any, output: Any,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
emit: Sequence[str] | None = None, emit: list[str] | None = None,
) -> str: ) -> str:
"""Request feedback from a human. """Request feedback from a human.
Args: Args:
@@ -2453,7 +2453,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
def _collapse_to_outcome( def _collapse_to_outcome(
self, self,
feedback: str, feedback: str,
outcomes: Sequence[str], outcomes: list[str],
llm: str | BaseLLM, llm: str | BaseLLM,
) -> str: ) -> str:
"""Collapse free-form feedback to a predefined outcome using LLM. """Collapse free-form feedback to a predefined outcome using LLM.

View File

@@ -53,7 +53,7 @@ Example (asynchronous with custom provider):
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Sequence from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
@@ -128,7 +128,7 @@ class HumanFeedbackConfig:
""" """
message: str message: str
emit: Sequence[str] | None = None emit: list[str] | None = None
llm: str | BaseLLM | None = None llm: str | BaseLLM | None = None
default_outcome: str | None = None default_outcome: str | None = None
metadata: dict[str, Any] | None = None metadata: dict[str, Any] | None = None
@@ -154,7 +154,7 @@ class HumanFeedbackMethod(FlowMethod[Any, Any]):
def human_feedback( def human_feedback(
message: str, message: str,
emit: Sequence[str] | None = None, emit: list[str] | None = None,
llm: str | BaseLLM | None = None, llm: str | BaseLLM | None = None,
default_outcome: str | None = None, default_outcome: str | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,

View File

@@ -404,7 +404,7 @@ class BaseLLM(ABC):
from_agent: Agent | None = None, from_agent: Agent | None = None,
tool_call: dict[str, Any] | None = None, tool_call: dict[str, Any] | None = None,
call_type: LLMCallType | None = None, call_type: LLMCallType | None = None,
response_id: str | None = None response_id: str | None = None,
) -> None: ) -> None:
"""Emit stream chunk event. """Emit stream chunk event.
@@ -427,7 +427,7 @@ class BaseLLM(ABC):
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
call_type=call_type, call_type=call_type,
response_id=response_id response_id=response_id,
), ),
) )
@@ -620,13 +620,11 @@ class BaseLLM(ABC):
try: try:
# Try to parse as JSON first # Try to parse as JSON first
if response.strip().startswith("{") or response.strip().startswith("["): if response.strip().startswith("{") or response.strip().startswith("["):
data = json.loads(response) return response_format.model_validate_json(response)
return response_format.model_validate(data)
json_match = _JSON_EXTRACTION_PATTERN.search(response) json_match = _JSON_EXTRACTION_PATTERN.search(response)
if json_match: if json_match:
data = json.loads(json_match.group()) return response_format.model_validate_json(json_match.group())
return response_format.model_validate(data)
raise ValueError("No JSON found in response") raise ValueError("No JSON found in response")

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping, Sequence from collections.abc import Sequence
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
import json import json
import logging import logging
@@ -538,7 +538,7 @@ class BedrockCompletion(BaseLLM):
self, self,
messages: list[LLMMessage], messages: list[LLMMessage],
body: BedrockConverseRequestBody, body: BedrockConverseRequestBody,
available_functions: Mapping[str, Any] | None = None, available_functions: dict[str, Any] | None = None,
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
@@ -1009,7 +1009,7 @@ class BedrockCompletion(BaseLLM):
self, self,
messages: list[LLMMessage], messages: list[LLMMessage],
body: BedrockConverseRequestBody, body: BedrockConverseRequestBody,
available_functions: Mapping[str, Any] | None = None, available_functions: dict[str, Any] | None = None,
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,

View File

@@ -1,6 +1,5 @@
"""Type definitions specific to ChromaDB implementation.""" """Type definitions specific to ChromaDB implementation."""
from collections.abc import Mapping
from typing import Any, NamedTuple from typing import Any, NamedTuple
from chromadb.api import AsyncClientAPI, ClientAPI from chromadb.api import AsyncClientAPI, ClientAPI
@@ -49,7 +48,7 @@ class PreparedDocuments(NamedTuple):
ids: list[str] ids: list[str]
texts: list[str] texts: list[str]
metadatas: list[Mapping[str, str | int | float | bool]] metadatas: list[dict[str, str | int | float | bool]]
class ExtractedSearchParams(NamedTuple): class ExtractedSearchParams(NamedTuple):

View File

@@ -1,6 +1,5 @@
"""Utility functions for ChromaDB client implementation.""" """Utility functions for ChromaDB client implementation."""
from collections.abc import Mapping
import hashlib import hashlib
import json import json
from typing import Literal, TypeGuard, cast from typing import Literal, TypeGuard, cast
@@ -66,7 +65,7 @@ def _prepare_documents_for_chromadb(
""" """
ids: list[str] = [] ids: list[str] = []
texts: list[str] = [] texts: list[str] = []
metadatas: list[Mapping[str, str | int | float | bool]] = [] metadatas: list[dict[str, str | int | float | bool]] = []
seen_ids: dict[str, int] = {} seen_ids: dict[str, int] = {}
try: try:
@@ -111,7 +110,7 @@ def _prepare_documents_for_chromadb(
def _create_batch_slice( def _create_batch_slice(
prepared: PreparedDocuments, start_index: int, batch_size: int prepared: PreparedDocuments, start_index: int, batch_size: int
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]: ) -> tuple[list[str], list[str], list[dict[str, str | int | float | bool]] | None]:
"""Create a batch slice from prepared documents. """Create a batch slice from prepared documents.
Args: Args:

View File

@@ -1,6 +1,8 @@
"""Type definitions for the embeddings module.""" """Type definitions for the embeddings module."""
from typing import Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
@@ -29,7 +31,7 @@ from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
ProviderSpec: TypeAlias = ( ProviderSpec: TypeAlias = Annotated[
AzureProviderSpec AzureProviderSpec
| BedrockProviderSpec | BedrockProviderSpec
| CohereProviderSpec | CohereProviderSpec
@@ -47,8 +49,9 @@ ProviderSpec: TypeAlias = (
| Text2VecProviderSpec | Text2VecProviderSpec
| VertexAIProviderSpec | VertexAIProviderSpec
| VoyageAIProviderSpec | VoyageAIProviderSpec
| WatsonXProviderSpec | WatsonXProviderSpec,
) Field(discriminator="provider"),
]
AllowedEmbeddingProviders = Literal[ AllowedEmbeddingProviders = Literal[
"azure", "azure",

View File

@@ -1,6 +1,6 @@
"""Type definitions for RAG (Retrieval-Augmented Generation) systems.""" """Type definitions for RAG (Retrieval-Augmented Generation) systems."""
from collections.abc import Callable, Mapping from collections.abc import Callable
from typing import Any, TypeAlias from typing import Any, TypeAlias
from typing_extensions import Required, TypedDict from typing_extensions import Required, TypedDict
@@ -19,8 +19,8 @@ class BaseRecord(TypedDict, total=False):
doc_id: str doc_id: str
content: Required[str] content: Required[str]
metadata: ( metadata: (
Mapping[str, str | int | float | bool] dict[str, str | int | float | bool]
| list[Mapping[str, str | int | float | bool]] | list[dict[str, str | int | float | bool]]
) )

View File

@@ -200,9 +200,12 @@ class CrewStructuredTool:
""" """
if isinstance(raw_args, str): if isinstance(raw_args, str):
try: try:
raw_args = json.loads(raw_args) validated_args = self.args_schema.model_validate_json(raw_args)
return validated_args.model_dump()
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments as JSON: {e}") from e raise ValueError(f"Failed to parse arguments as JSON: {e}") from e
except Exception as e:
raise ValueError(f"Arguments validation failed: {e}") from e
try: try:
validated_args = self.args_schema.model_validate(raw_args) validated_args = self.args_schema.model_validate(raw_args)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Sequence from collections.abc import Callable
import json import json
import re import re
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
@@ -98,7 +98,7 @@ def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
return tools_list return tools_list
def get_tool_names(tools: Sequence[CrewStructuredTool | BaseTool]) -> str: def get_tool_names(tools: list[CrewStructuredTool | BaseTool]) -> str:
"""Get the sanitized names of the tools. """Get the sanitized names of the tools.
Args: Args:
@@ -111,7 +111,7 @@ def get_tool_names(tools: Sequence[CrewStructuredTool | BaseTool]) -> str:
def render_text_description_and_args( def render_text_description_and_args(
tools: Sequence[CrewStructuredTool | BaseTool], tools: list[CrewStructuredTool | BaseTool],
) -> str: ) -> str:
"""Render the tool name, description, and args in plain text. """Render the tool name, description, and args in plain text.
@@ -130,7 +130,7 @@ def render_text_description_and_args(
def convert_tools_to_openai_schema( def convert_tools_to_openai_schema(
tools: Sequence[BaseTool | CrewStructuredTool], tools: list[BaseTool | CrewStructuredTool],
) -> tuple[list[dict[str, Any]], dict[str, Callable[..., Any]]]: ) -> tuple[list[dict[str, Any]], dict[str, Callable[..., Any]]]:
"""Convert CrewAI tools to OpenAI function calling format. """Convert CrewAI tools to OpenAI function calling format.