From 8354cdf061bafc21d2e63cb9a1e8e902dd6cfca5 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 4 Sep 2025 11:41:57 -0400 Subject: [PATCH] fix: add missing type annotations to fix mypy strict mode errors Added type annotations to 10 files to resolve mypy type checking errors: - Added return type annotations to methods missing them - Added parameter type annotations where missing - Fixed Optional type hints to be explicit - Removed redundant type cast in crew.py - Changed _execute_with_timeout return type from str to Any in agent.py Additional type errors remain in other files throughout the codebase. --- src/crewai/agent.py | 2 +- src/crewai/crew.py | 2 +- src/crewai/events/event_listener.py | 134 +++++++++++------- src/crewai/knowledge/knowledge.py | 4 +- src/crewai/memory/entity/entity_memory.py | 8 +- .../memory/long_term/long_term_memory.py | 12 +- .../memory/short_term/short_term_memory.py | 10 +- src/crewai/tasks/conditional_task.py | 8 +- .../utilities/evaluators/task_evaluator.py | 8 +- src/crewai/utilities/rpm_controller.py | 12 +- 10 files changed, 121 insertions(+), 79 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index c0e998ebf..510cc9fa6 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -497,7 +497,7 @@ class Agent(BaseAgent): ) return result - def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> str: + def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> Any: """Execute a task with a timeout. Args: diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 2ed94b3ca..96f140085 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -908,7 +908,7 @@ class Crew(FlowTrackable, BaseModel): f"Skipping conditional task: {task.description}", color="yellow", ) - skipped_task_output = cast(TaskOutput, task.get_skipped_task_output()) + skipped_task_output = task.get_skipped_task_output() if not was_replayed: self._store_execution_log(task, skipped_task_output, task_index) diff --git a/src/crewai/events/event_listener.py b/src/crewai/events/event_listener.py index d9fb41875..f4df1407d 100644 --- a/src/crewai/events/event_listener.py +++ b/src/crewai/events/event_listener.py @@ -93,7 +93,7 @@ class EventListener(BaseEventListener): cls._instance._initialized = False return cls._instance - def __init__(self): + def __init__(self) -> None: if not hasattr(self, "_initialized") or not self._initialized: super().__init__() self._telemetry = Telemetry() @@ -106,14 +106,14 @@ class EventListener(BaseEventListener): # ----------- CREW EVENTS ----------- - def setup_listeners(self, crewai_event_bus) -> None: + def setup_listeners(self, crewai_event_bus: Any) -> None: @crewai_event_bus.on(CrewKickoffStartedEvent) - def on_crew_started(source, event: CrewKickoffStartedEvent): + def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None: self.formatter.create_crew_tree(event.crew_name or "Crew", source.id) self._telemetry.crew_execution_span(source, event.inputs) @crewai_event_bus.on(CrewKickoffCompletedEvent) - def on_crew_completed(source, event: CrewKickoffCompletedEvent): + def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None: # Handle telemetry final_string_output = event.output.raw self._telemetry.end_crew(source, final_string_output) @@ -127,7 +127,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(CrewKickoffFailedEvent) - def on_crew_failed(source, event: CrewKickoffFailedEvent): + def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None: self.formatter.update_crew_tree( self.formatter.current_crew_tree, event.crew_name or "Crew", @@ -136,23 +136,25 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(CrewTrainStartedEvent) - def on_crew_train_started(source, event: CrewTrainStartedEvent): + def on_crew_train_started(source: Any, event: CrewTrainStartedEvent) -> None: self.formatter.handle_crew_train_started( event.crew_name or "Crew", str(event.timestamp) ) @crewai_event_bus.on(CrewTrainCompletedEvent) - def on_crew_train_completed(source, event: CrewTrainCompletedEvent): + def on_crew_train_completed( + source: Any, event: CrewTrainCompletedEvent + ) -> None: self.formatter.handle_crew_train_completed( event.crew_name or "Crew", str(event.timestamp) ) @crewai_event_bus.on(CrewTrainFailedEvent) - def on_crew_train_failed(source, event: CrewTrainFailedEvent): + def on_crew_train_failed(source: Any, event: CrewTrainFailedEvent) -> None: self.formatter.handle_crew_train_failed(event.crew_name or "Crew") @crewai_event_bus.on(CrewTestResultEvent) - def on_crew_test_result(source, event: CrewTestResultEvent): + def on_crew_test_result(source: Any, event: CrewTestResultEvent) -> None: self._telemetry.individual_test_result_span( source.crew, event.quality, @@ -163,7 +165,7 @@ class EventListener(BaseEventListener): # ----------- TASK EVENTS ----------- @crewai_event_bus.on(TaskStartedEvent) - def on_task_started(source, event: TaskStartedEvent): + def on_task_started(source: Any, event: TaskStartedEvent) -> None: span = self._telemetry.task_started(crew=source.agent.crew, task=source) self.execution_spans[source] = span # Pass both task ID and task name (if set) @@ -173,7 +175,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(TaskCompletedEvent) - def on_task_completed(source, event: TaskCompletedEvent): + def on_task_completed(source: Any, event: TaskCompletedEvent) -> None: # Handle telemetry span = self.execution_spans.get(source) if span: @@ -191,7 +193,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(TaskFailedEvent) - def on_task_failed(source, event: TaskFailedEvent): + def on_task_failed(source: Any, event: TaskFailedEvent) -> None: span = self.execution_spans.get(source) if span: if source.agent and source.agent.crew: @@ -211,7 +213,9 @@ class EventListener(BaseEventListener): # ----------- AGENT EVENTS ----------- @crewai_event_bus.on(AgentExecutionStartedEvent) - def on_agent_execution_started(source, event: AgentExecutionStartedEvent): + def on_agent_execution_started( + source: Any, event: AgentExecutionStartedEvent + ) -> None: self.formatter.create_agent_branch( self.formatter.current_task_branch, event.agent.role, @@ -219,7 +223,9 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(AgentExecutionCompletedEvent) - def on_agent_execution_completed(source, event: AgentExecutionCompletedEvent): + def on_agent_execution_completed( + source: Any, event: AgentExecutionCompletedEvent + ) -> None: self.formatter.update_agent_status( self.formatter.current_agent_branch, event.agent.role, @@ -230,8 +236,8 @@ class EventListener(BaseEventListener): @crewai_event_bus.on(LiteAgentExecutionStartedEvent) def on_lite_agent_execution_started( - source, event: LiteAgentExecutionStartedEvent - ): + source: Any, event: LiteAgentExecutionStartedEvent + ) -> None: """Handle LiteAgent execution started event.""" self.formatter.handle_lite_agent_execution( event.agent_info["role"], status="started", **event.agent_info @@ -239,15 +245,17 @@ class EventListener(BaseEventListener): @crewai_event_bus.on(LiteAgentExecutionCompletedEvent) def on_lite_agent_execution_completed( - source, event: LiteAgentExecutionCompletedEvent - ): + source: Any, event: LiteAgentExecutionCompletedEvent + ) -> None: """Handle LiteAgent execution completed event.""" self.formatter.handle_lite_agent_execution( event.agent_info["role"], status="completed", **event.agent_info ) @crewai_event_bus.on(LiteAgentExecutionErrorEvent) - def on_lite_agent_execution_error(source, event: LiteAgentExecutionErrorEvent): + def on_lite_agent_execution_error( + source: Any, event: LiteAgentExecutionErrorEvent + ) -> None: """Handle LiteAgent execution error event.""" self.formatter.handle_lite_agent_execution( event.agent_info["role"], @@ -259,25 +267,27 @@ class EventListener(BaseEventListener): # ----------- FLOW EVENTS ----------- @crewai_event_bus.on(FlowCreatedEvent) - def on_flow_created(source, event: FlowCreatedEvent): + def on_flow_created(source: Any, event: FlowCreatedEvent) -> None: self._telemetry.flow_creation_span(event.flow_name) self.formatter.create_flow_tree(event.flow_name, str(source.flow_id)) @crewai_event_bus.on(FlowStartedEvent) - def on_flow_started(source, event: FlowStartedEvent): + def on_flow_started(source: Any, event: FlowStartedEvent) -> None: self._telemetry.flow_execution_span( event.flow_name, list(source._methods.keys()) ) self.formatter.start_flow(event.flow_name, str(source.flow_id)) @crewai_event_bus.on(FlowFinishedEvent) - def on_flow_finished(source, event: FlowFinishedEvent): + def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None: self.formatter.update_flow_status( self.formatter.current_flow_tree, event.flow_name, source.flow_id ) @crewai_event_bus.on(MethodExecutionStartedEvent) - def on_method_execution_started(source, event: MethodExecutionStartedEvent): + def on_method_execution_started( + source: Any, event: MethodExecutionStartedEvent + ) -> None: self.formatter.update_method_status( self.formatter.current_method_branch, self.formatter.current_flow_tree, @@ -286,7 +296,9 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(MethodExecutionFinishedEvent) - def on_method_execution_finished(source, event: MethodExecutionFinishedEvent): + def on_method_execution_finished( + source: Any, event: MethodExecutionFinishedEvent + ) -> None: self.formatter.update_method_status( self.formatter.current_method_branch, self.formatter.current_flow_tree, @@ -295,7 +307,9 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(MethodExecutionFailedEvent) - def on_method_execution_failed(source, event: MethodExecutionFailedEvent): + def on_method_execution_failed( + source: Any, event: MethodExecutionFailedEvent + ) -> None: self.formatter.update_method_status( self.formatter.current_method_branch, self.formatter.current_flow_tree, @@ -306,7 +320,7 @@ class EventListener(BaseEventListener): # ----------- TOOL USAGE EVENTS ----------- @crewai_event_bus.on(ToolUsageStartedEvent) - def on_tool_usage_started(source, event: ToolUsageStartedEvent): + def on_tool_usage_started(source: Any, event: ToolUsageStartedEvent) -> None: if isinstance(source, LLM): self.formatter.handle_llm_tool_usage_started( event.tool_name, @@ -320,7 +334,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(ToolUsageFinishedEvent) - def on_tool_usage_finished(source, event: ToolUsageFinishedEvent): + def on_tool_usage_finished(source: Any, event: ToolUsageFinishedEvent) -> None: if isinstance(source, LLM): self.formatter.handle_llm_tool_usage_finished( event.tool_name, @@ -333,7 +347,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(ToolUsageErrorEvent) - def on_tool_usage_error(source, event: ToolUsageErrorEvent): + def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None: if isinstance(source, LLM): self.formatter.handle_llm_tool_usage_error( event.tool_name, @@ -350,7 +364,7 @@ class EventListener(BaseEventListener): # ----------- LLM EVENTS ----------- @crewai_event_bus.on(LLMCallStartedEvent) - def on_llm_call_started(source, event: LLMCallStartedEvent): + def on_llm_call_started(source: Any, event: LLMCallStartedEvent) -> None: # Capture the returned tool branch and update the current_tool_branch reference thinking_branch = self.formatter.handle_llm_call_started( self.formatter.current_agent_branch, @@ -361,7 +375,7 @@ class EventListener(BaseEventListener): self.formatter.current_tool_branch = thinking_branch @crewai_event_bus.on(LLMCallCompletedEvent) - def on_llm_call_completed(source, event: LLMCallCompletedEvent): + def on_llm_call_completed(source: Any, event: LLMCallCompletedEvent) -> None: self.formatter.handle_llm_call_completed( self.formatter.current_tool_branch, self.formatter.current_agent_branch, @@ -369,7 +383,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(LLMCallFailedEvent) - def on_llm_call_failed(source, event: LLMCallFailedEvent): + def on_llm_call_failed(source: Any, event: LLMCallFailedEvent) -> None: self.formatter.handle_llm_call_failed( self.formatter.current_tool_branch, event.error, @@ -377,7 +391,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(LLMStreamChunkEvent) - def on_llm_stream_chunk(source, event: LLMStreamChunkEvent): + def on_llm_stream_chunk(source: Any, event: LLMStreamChunkEvent) -> None: self.text_stream.write(event.chunk) self.text_stream.seek(self.next_chunk) @@ -390,7 +404,9 @@ class EventListener(BaseEventListener): # ----------- LLM GUARDRAIL EVENTS ----------- @crewai_event_bus.on(LLMGuardrailStartedEvent) - def on_llm_guardrail_started(source, event: LLMGuardrailStartedEvent): + def on_llm_guardrail_started( + source: Any, event: LLMGuardrailStartedEvent + ) -> None: guardrail_str = str(event.guardrail) guardrail_name = ( guardrail_str[:50] + "..." if len(guardrail_str) > 50 else guardrail_str @@ -399,13 +415,15 @@ class EventListener(BaseEventListener): self.formatter.handle_guardrail_started(guardrail_name, event.retry_count) @crewai_event_bus.on(LLMGuardrailCompletedEvent) - def on_llm_guardrail_completed(source, event: LLMGuardrailCompletedEvent): + def on_llm_guardrail_completed( + source: Any, event: LLMGuardrailCompletedEvent + ) -> None: self.formatter.handle_guardrail_completed( event.success, event.error, event.retry_count ) @crewai_event_bus.on(CrewTestStartedEvent) - def on_crew_test_started(source, event: CrewTestStartedEvent): + def on_crew_test_started(source: Any, event: CrewTestStartedEvent) -> None: cloned_crew = source.copy() self._telemetry.test_execution_span( cloned_crew, @@ -419,20 +437,20 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(CrewTestCompletedEvent) - def on_crew_test_completed(source, event: CrewTestCompletedEvent): + def on_crew_test_completed(source: Any, event: CrewTestCompletedEvent) -> None: self.formatter.handle_crew_test_completed( self.formatter.current_flow_tree, event.crew_name or "Crew", ) @crewai_event_bus.on(CrewTestFailedEvent) - def on_crew_test_failed(source, event: CrewTestFailedEvent): + def on_crew_test_failed(source: Any, event: CrewTestFailedEvent) -> None: self.formatter.handle_crew_test_failed(event.crew_name or "Crew") @crewai_event_bus.on(KnowledgeRetrievalStartedEvent) def on_knowledge_retrieval_started( - source, event: KnowledgeRetrievalStartedEvent - ): + source: Any, event: KnowledgeRetrievalStartedEvent + ) -> None: if self.knowledge_retrieval_in_progress: return @@ -445,8 +463,8 @@ class EventListener(BaseEventListener): @crewai_event_bus.on(KnowledgeRetrievalCompletedEvent) def on_knowledge_retrieval_completed( - source, event: KnowledgeRetrievalCompletedEvent - ): + source: Any, event: KnowledgeRetrievalCompletedEvent + ) -> None: if not self.knowledge_retrieval_in_progress: return @@ -458,11 +476,15 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(KnowledgeQueryStartedEvent) - def on_knowledge_query_started(source, event: KnowledgeQueryStartedEvent): + def on_knowledge_query_started( + source: Any, event: KnowledgeQueryStartedEvent + ) -> None: pass @crewai_event_bus.on(KnowledgeQueryFailedEvent) - def on_knowledge_query_failed(source, event: KnowledgeQueryFailedEvent): + def on_knowledge_query_failed( + source: Any, event: KnowledgeQueryFailedEvent + ) -> None: self.formatter.handle_knowledge_query_failed( self.formatter.current_agent_branch, event.error, @@ -470,13 +492,15 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(KnowledgeQueryCompletedEvent) - def on_knowledge_query_completed(source, event: KnowledgeQueryCompletedEvent): + def on_knowledge_query_completed( + source: Any, event: KnowledgeQueryCompletedEvent + ) -> None: pass @crewai_event_bus.on(KnowledgeSearchQueryFailedEvent) def on_knowledge_search_query_failed( - source, event: KnowledgeSearchQueryFailedEvent - ): + source: Any, event: KnowledgeSearchQueryFailedEvent + ) -> None: self.formatter.handle_knowledge_search_query_failed( self.formatter.current_agent_branch, event.error, @@ -486,7 +510,9 @@ class EventListener(BaseEventListener): # ----------- REASONING EVENTS ----------- @crewai_event_bus.on(AgentReasoningStartedEvent) - def on_agent_reasoning_started(source, event: AgentReasoningStartedEvent): + def on_agent_reasoning_started( + source: Any, event: AgentReasoningStartedEvent + ) -> None: self.formatter.handle_reasoning_started( self.formatter.current_agent_branch, event.attempt, @@ -494,7 +520,9 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(AgentReasoningCompletedEvent) - def on_agent_reasoning_completed(source, event: AgentReasoningCompletedEvent): + def on_agent_reasoning_completed( + source: Any, event: AgentReasoningCompletedEvent + ) -> None: self.formatter.handle_reasoning_completed( event.plan, event.ready, @@ -502,7 +530,9 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(AgentReasoningFailedEvent) - def on_agent_reasoning_failed(source, event: AgentReasoningFailedEvent): + def on_agent_reasoning_failed( + source: Any, event: AgentReasoningFailedEvent + ) -> None: self.formatter.handle_reasoning_failed( event.error, self.formatter.current_crew_tree, @@ -511,7 +541,7 @@ class EventListener(BaseEventListener): # ----------- AGENT LOGGING EVENTS ----------- @crewai_event_bus.on(AgentLogsStartedEvent) - def on_agent_logs_started(source, event: AgentLogsStartedEvent): + def on_agent_logs_started(source: Any, event: AgentLogsStartedEvent) -> None: self.formatter.handle_agent_logs_started( event.agent_role, event.task_description, @@ -519,7 +549,9 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(AgentLogsExecutionEvent) - def on_agent_logs_execution(source, event: AgentLogsExecutionEvent): + def on_agent_logs_execution( + source: Any, event: AgentLogsExecutionEvent + ) -> None: self.formatter.handle_agent_logs_execution( event.agent_role, event.formatted_answer, diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py index 86b0663d6..43fba9ad6 100644 --- a/src/crewai/knowledge/knowledge.py +++ b/src/crewai/knowledge/knowledge.py @@ -30,8 +30,8 @@ class Knowledge(BaseModel): sources: list[BaseKnowledgeSource], embedder: Optional[dict[str, Any]] = None, storage: Optional[KnowledgeStorage] = None, - **data, - ): + **data: Any, + ) -> None: super().__init__(**data) if storage: self.storage = storage diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index 1664e8050..81d795b2a 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -27,7 +27,11 @@ class EntityMemory(Memory): _memory_provider: str | None = PrivateAttr() def __init__( - self, crew=None, embedder_config=None, storage=None, path=None + self, + crew: Any = None, + embedder_config: Any = None, + storage: Any = None, + path: Any = None, ) -> None: memory_provider = embedder_config.get("provider") if embedder_config else None if memory_provider == "mem0": @@ -157,7 +161,7 @@ class EntityMemory(Memory): query: str, limit: int = 3, score_threshold: float = 0.35, - ): + ) -> Any: crewai_event_bus.emit( self, event=MemoryQueryStartedEvent( diff --git a/src/crewai/memory/long_term/long_term_memory.py b/src/crewai/memory/long_term/long_term_memory.py index f627dc09c..82c6632cc 100644 --- a/src/crewai/memory/long_term/long_term_memory.py +++ b/src/crewai/memory/long_term/long_term_memory.py @@ -24,12 +24,12 @@ class LongTermMemory(Memory): LongTermMemoryItem instances. """ - def __init__(self, storage=None, path=None) -> None: + def __init__(self, storage: Any = None, path: Any = None) -> None: if not storage: storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage() super().__init__(storage=storage) - def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" + def save(self, item: LongTermMemoryItem) -> None: crewai_event_bus.emit( self, event=MemorySaveStartedEvent( @@ -48,7 +48,7 @@ class LongTermMemory(Memory): metadata.update( {"agent": item.agent, "expected_output": item.expected_output} ) - self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage" + self.storage.save( task_description=item.task, score=metadata["quality"], metadata=metadata, @@ -80,11 +80,11 @@ class LongTermMemory(Memory): ) raise - def search( # type: ignore # signature of "search" incompatible with supertype "Memory" + def search( self, task: str, latest_n: int = 3, - ) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory" + ) -> list[dict[str, Any]]: crewai_event_bus.emit( self, event=MemoryQueryStartedEvent( @@ -98,7 +98,7 @@ class LongTermMemory(Memory): start_time = time.time() try: - results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load" + results = self.storage.load(task, latest_n) crewai_event_bus.emit( self, diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index 9de768b8f..f75e4be0b 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -29,7 +29,11 @@ class ShortTermMemory(Memory): _memory_provider: Optional[str] = PrivateAttr() def __init__( - self, crew=None, embedder_config=None, storage=None, path=None + self, + crew: Any = None, + embedder_config: Any = None, + storage: Any = None, + path: Any = None, ) -> None: memory_provider = embedder_config.get("provider") if embedder_config else None if memory_provider == "mem0": @@ -116,7 +120,7 @@ class ShortTermMemory(Memory): query: str, limit: int = 3, score_threshold: float = 0.35, - ): + ) -> Any: crewai_event_bus.emit( self, event=MemoryQueryStartedEvent( @@ -133,7 +137,7 @@ class ShortTermMemory(Memory): try: results = self.storage.search( query=query, limit=limit, score_threshold=score_threshold - ) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters + ) crewai_event_bus.emit( self, diff --git a/src/crewai/tasks/conditional_task.py b/src/crewai/tasks/conditional_task.py index 6258c9ccb..b199ef457 100644 --- a/src/crewai/tasks/conditional_task.py +++ b/src/crewai/tasks/conditional_task.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Any +from typing import Any, Optional from pydantic import Field @@ -14,7 +14,7 @@ class ConditionalTask(Task): Note: This cannot be the only task you have in your crew and cannot be the first since its needs context from the previous task. """ - condition: Callable[[TaskOutput], bool] = Field( + condition: Optional[Callable[[TaskOutput], bool]] = Field( default=None, description="Maximum number of retries for an agent to execute a task when an error occurs.", ) @@ -22,8 +22,8 @@ class ConditionalTask(Task): def __init__( self, condition: Callable[[Any], bool], - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.condition = condition diff --git a/src/crewai/utilities/evaluators/task_evaluator.py b/src/crewai/utilities/evaluators/task_evaluator.py index 70babc363..af9fdc99b 100644 --- a/src/crewai/utilities/evaluators/task_evaluator.py +++ b/src/crewai/utilities/evaluators/task_evaluator.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, Field from crewai.events.event_bus import crewai_event_bus @@ -39,11 +41,11 @@ class TrainingTaskEvaluation(BaseModel): class TaskEvaluator: - def __init__(self, original_agent) -> None: + def __init__(self, original_agent: Any) -> None: self.llm = original_agent.llm self.original_agent = original_agent - def evaluate(self, task, output) -> TaskEvaluation: + def evaluate(self, task: Any, output: Any) -> TaskEvaluation: crewai_event_bus.emit( self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task) ) @@ -74,7 +76,7 @@ class TaskEvaluator: return converter.to_pydantic() def evaluate_training_data( - self, training_data: dict, agent_id: str + self, training_data: dict[str, Any], agent_id: str ) -> TrainingTaskEvaluation: """ Evaluate the training data based on the llm output, human feedback, and improved output. diff --git a/src/crewai/utilities/rpm_controller.py b/src/crewai/utilities/rpm_controller.py index bbc4d5c56..946789b0c 100644 --- a/src/crewai/utilities/rpm_controller.py +++ b/src/crewai/utilities/rpm_controller.py @@ -20,18 +20,18 @@ class RPMController(BaseModel): _shutdown_flag: bool = PrivateAttr(default=False) @model_validator(mode="after") - def reset_counter(self): + def reset_counter(self) -> "RPMController": if self.max_rpm is not None: if not self._shutdown_flag: self._lock = threading.Lock() self._reset_request_count() return self - def check_or_wait(self): + def check_or_wait(self) -> bool: if self.max_rpm is None: return True - def _check_and_increment(): + def _check_and_increment() -> bool: if self.max_rpm is not None and self._current_rpm < self.max_rpm: self._current_rpm += 1 return True @@ -55,12 +55,12 @@ class RPMController(BaseModel): self._timer.cancel() self._timer = None - def _wait_for_next_minute(self): + def _wait_for_next_minute(self) -> None: time.sleep(60) self._current_rpm = 0 - def _reset_request_count(self): - def _reset(): + def _reset_request_count(self) -> None: + def _reset() -> None: self._current_rpm = 0 if not self._shutdown_flag: self._timer = threading.Timer(60.0, self._reset_request_count)