fix: add type annotations and exclude tests from mypy

- Add type: ignore for mem0 import
- Fix tool_usage.py cache_function None check
- Change _execute_without_timeout return type to Any
- Add type annotations to multiple functions:
  - add_sources() -> None
  - log() with proper parameter types
  - stop_rpm_counter() -> None
  - EventListener.__new__() -> Self
  - setup_listeners() -> None
  - Memory class __init__ methods -> None
  - TaskEvaluator.__init__() -> None
  - get_skipped_task_output() -> TaskOutput
- Exclude tests directory from mypy checks in pyproject.toml
- Update deprecated typing imports to use built-in types
This commit is contained in:
Greyson LaLonde
2025-09-04 11:11:59 -04:00
parent 0bab041531
commit 2ba48dd82a
13 changed files with 99 additions and 94 deletions

View File

@@ -123,7 +123,7 @@ select = [
[tool.mypy] [tool.mypy]
strict = true strict = true
exclude = ["src/crewai/cli/templates"] exclude = ["src/crewai/cli/templates", "tests"]
[tool.bandit] [tool.bandit]
exclude_dirs = ["src/crewai/cli/templates"] exclude_dirs = ["src/crewai/cli/templates"]

View File

@@ -530,7 +530,7 @@ class Agent(BaseAgent):
future.cancel() future.cancel()
raise RuntimeError(f"Task execution failed: {str(e)}") raise RuntimeError(f"Task execution failed: {str(e)}")
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str: def _execute_without_timeout(self, task_prompt: str, task: Task) -> Any:
"""Execute a task without a timeout. """Execute a task without a timeout.
Args: Args:

View File

@@ -1,15 +1,31 @@
from __future__ import annotations from __future__ import annotations
from io import StringIO from io import StringIO
from typing import Any, Dict from typing import Any
from pydantic import Field, PrivateAttr from pydantic import Field, PrivateAttr
from crewai.llm import LLM from typing_extensions import Self
from crewai.task import Task
from crewai.telemetry.telemetry import Telemetry
from crewai.utilities import Logger
from crewai.utilities.constants import EMITTER_COLOR
from crewai.events.base_event_listener import BaseEventListener from crewai.events.base_event_listener import BaseEventListener
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestResultEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.types.knowledge_events import ( from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent, KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent, KnowledgeQueryFailedEvent,
@@ -25,34 +41,21 @@ from crewai.events.types.llm_events import (
LLMStreamChunkEvent, LLMStreamChunkEvent,
) )
from crewai.events.types.llm_guardrail_events import ( from crewai.events.types.llm_guardrail_events import (
LLMGuardrailStartedEvent,
LLMGuardrailCompletedEvent, LLMGuardrailCompletedEvent,
) LLMGuardrailStartedEvent,
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
) )
from crewai.events.types.logging_events import ( from crewai.events.types.logging_events import (
AgentLogsStartedEvent,
AgentLogsExecutionEvent, AgentLogsExecutionEvent,
AgentLogsStartedEvent,
) )
from crewai.events.types.crew_events import ( from crewai.events.utils.console_formatter import ConsoleFormatter
CrewKickoffCompletedEvent, from crewai.llm import LLM
CrewKickoffFailedEvent, from crewai.task import Task
CrewKickoffStartedEvent, from crewai.telemetry.telemetry import Telemetry
CrewTestCompletedEvent, from crewai.utilities import Logger
CrewTestFailedEvent, from crewai.utilities.constants import EMITTER_COLOR
CrewTestResultEvent,
CrewTestStartedEvent, from .listeners.memory_listener import MemoryListener
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from .types.flow_events import ( from .types.flow_events import (
FlowCreatedEvent, FlowCreatedEvent,
FlowFinishedEvent, FlowFinishedEvent,
@@ -61,32 +64,30 @@ from .types.flow_events import (
MethodExecutionFinishedEvent, MethodExecutionFinishedEvent,
MethodExecutionStartedEvent, MethodExecutionStartedEvent,
) )
from .types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
)
from .types.task_events import TaskCompletedEvent, TaskFailedEvent, TaskStartedEvent from .types.task_events import TaskCompletedEvent, TaskFailedEvent, TaskStartedEvent
from .types.tool_usage_events import ( from .types.tool_usage_events import (
ToolUsageErrorEvent, ToolUsageErrorEvent,
ToolUsageFinishedEvent, ToolUsageFinishedEvent,
ToolUsageStartedEvent, ToolUsageStartedEvent,
) )
from .types.reasoning_events import (
AgentReasoningStartedEvent,
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
)
from .listeners.memory_listener import MemoryListener
class EventListener(BaseEventListener): class EventListener(BaseEventListener):
_instance = None _instance = None
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry()) _telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
logger = Logger(verbose=True, default_color=EMITTER_COLOR) logger = Logger(verbose=True, default_color=EMITTER_COLOR)
execution_spans: Dict[Task, Any] = Field(default_factory=dict) execution_spans: dict[Task, Any] = Field(default_factory=dict)
next_chunk = 0 next_chunk = 0
text_stream = StringIO() text_stream = StringIO()
knowledge_retrieval_in_progress = False knowledge_retrieval_in_progress = False
knowledge_query_in_progress = False knowledge_query_in_progress = False
def __new__(cls): def __new__(cls) -> Self:
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance._initialized = False cls._instance._initialized = False
@@ -105,7 +106,7 @@ class EventListener(BaseEventListener):
# ----------- CREW EVENTS ----------- # ----------- CREW EVENTS -----------
def setup_listeners(self, crewai_event_bus): def setup_listeners(self, crewai_event_bus) -> None:
@crewai_event_bus.on(CrewKickoffStartedEvent) @crewai_event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source, event: CrewKickoffStartedEvent): def on_crew_started(source, event: CrewKickoffStartedEvent):
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id) self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Dict, List, Optional from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
@@ -18,17 +18,17 @@ class Knowledge(BaseModel):
embedder: Optional[Dict[str, Any]] = None embedder: Optional[Dict[str, Any]] = None
""" """
sources: List[BaseKnowledgeSource] = Field(default_factory=list) sources: list[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None) storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None embedder: Optional[dict[str, Any]] = None
collection_name: Optional[str] = None collection_name: Optional[str] = None
def __init__( def __init__(
self, self,
collection_name: str, collection_name: str,
sources: List[BaseKnowledgeSource], sources: list[BaseKnowledgeSource],
embedder: Optional[Dict[str, Any]] = None, embedder: Optional[dict[str, Any]] = None,
storage: Optional[KnowledgeStorage] = None, storage: Optional[KnowledgeStorage] = None,
**data, **data,
): ):
@@ -43,8 +43,8 @@ class Knowledge(BaseModel):
self.storage.initialize_knowledge_storage() self.storage.initialize_knowledge_storage()
def query( def query(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35 self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Query across all knowledge sources to find the most relevant information. Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks. Returns the top_k most relevant chunks.
@@ -62,7 +62,7 @@ class Knowledge(BaseModel):
) )
return results return results
def add_sources(self): def add_sources(self) -> None:
try: try:
for source in self.sources: for source in self.sources:
source.storage = self.storage source.storage = self.storage

View File

@@ -1,20 +1,20 @@
from typing import Any
import time import time
from typing import Any
from pydantic import PrivateAttr from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.entity.entity_memory_item import EntityMemoryItem from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class EntityMemory(Memory): class EntityMemory(Memory):
@@ -26,7 +26,9 @@ class EntityMemory(Memory):
_memory_provider: str | None = PrivateAttr() _memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None): def __init__(
self, crew=None, embedder_config=None, storage=None, path=None
) -> None:
memory_provider = embedder_config.get("provider") if embedder_config else None memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0": if memory_provider == "mem0":
try: try:

View File

@@ -1,17 +1,17 @@
from typing import Any, Dict, List
import time import time
from typing import Any
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import ( from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent, MemoryQueryCompletedEvent,
MemoryQueryFailedEvent, MemoryQueryFailedEvent,
MemorySaveStartedEvent, MemoryQueryStartedEvent,
MemorySaveCompletedEvent, MemorySaveCompletedEvent,
MemorySaveFailedEvent, MemorySaveFailedEvent,
MemorySaveStartedEvent,
) )
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
@@ -24,7 +24,7 @@ class LongTermMemory(Memory):
LongTermMemoryItem instances. LongTermMemoryItem instances.
""" """
def __init__(self, storage=None, path=None): def __init__(self, storage=None, path=None) -> None:
if not storage: if not storage:
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage() storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage) super().__init__(storage=storage)
@@ -84,7 +84,7 @@ class LongTermMemory(Memory):
self, self,
task: str, task: str,
latest_n: int = 3, latest_n: int = 3,
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory" ) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=MemoryQueryStartedEvent( event=MemoryQueryStartedEvent(

View File

@@ -1,20 +1,20 @@
from typing import Any, Dict, Optional
import time import time
from typing import Any, Optional
from pydantic import PrivateAttr from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.memory import Memory from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class ShortTermMemory(Memory): class ShortTermMemory(Memory):
@@ -28,7 +28,9 @@ class ShortTermMemory(Memory):
_memory_provider: Optional[str] = PrivateAttr() _memory_provider: Optional[str] = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None): def __init__(
self, crew=None, embedder_config=None, storage=None, path=None
) -> None:
memory_provider = embedder_config.get("provider") if embedder_config else None memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0": if memory_provider == "mem0":
try: try:
@@ -56,7 +58,7 @@ class ShortTermMemory(Memory):
def save( def save(
self, self,
value: Any, value: Any,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,

View File

@@ -2,7 +2,7 @@ import os
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any
from mem0 import Memory, MemoryClient from mem0 import Memory, MemoryClient # type: ignore[import-not-found]
from crewai.memory.storage.interface import Storage from crewai.memory.storage.interface import Storage
from crewai.utilities.chromadb import sanitize_collection_name from crewai.utilities.chromadb import sanitize_collection_name

View File

@@ -1,4 +1,5 @@
from typing import Any, Callable from collections.abc import Callable
from typing import Any
from pydantic import Field from pydantic import Field
@@ -38,7 +39,7 @@ class ConditionalTask(Task):
""" """
return self.condition(context) return self.condition(context)
def get_skipped_task_output(self): def get_skipped_task_output(self) -> TaskOutput:
return TaskOutput( return TaskOutput(
description=self.description, description=self.description,
raw="", raw="",

View File

@@ -287,7 +287,8 @@ class ToolUsage:
if self.tools_handler: if self.tools_handler:
should_cache = True should_cache = True
if ( if (
hasattr(available_tool, "cache_function") available_tool
and hasattr(available_tool, "cache_function")
and available_tool.cache_function and available_tool.cache_function
): ):
should_cache = available_tool.cache_function( should_cache = available_tool.cache_function(

View File

@@ -1,10 +1,8 @@
from typing import List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.utilities import Converter
from crewai.events.types.task_events import TaskEvaluationEvent
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.task_events import TaskEvaluationEvent
from crewai.utilities import Converter
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
from crewai.utilities.training_converter import TrainingConverter from crewai.utilities.training_converter import TrainingConverter
@@ -13,23 +11,23 @@ class Entity(BaseModel):
name: str = Field(description="The name of the entity.") name: str = Field(description="The name of the entity.")
type: str = Field(description="The type of the entity.") type: str = Field(description="The type of the entity.")
description: str = Field(description="Description of the entity.") description: str = Field(description="Description of the entity.")
relationships: List[str] = Field(description="Relationships of the entity.") relationships: list[str] = Field(description="Relationships of the entity.")
class TaskEvaluation(BaseModel): class TaskEvaluation(BaseModel):
suggestions: List[str] = Field( suggestions: list[str] = Field(
description="Suggestions to improve future similar tasks." description="Suggestions to improve future similar tasks."
) )
quality: float = Field( quality: float = Field(
description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task." description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task."
) )
entities: List[Entity] = Field( entities: list[Entity] = Field(
description="Entities extracted from the task output." description="Entities extracted from the task output."
) )
class TrainingTaskEvaluation(BaseModel): class TrainingTaskEvaluation(BaseModel):
suggestions: List[str] = Field( suggestions: list[str] = Field(
description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions." description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions."
) )
quality: float = Field( quality: float = Field(
@@ -41,7 +39,7 @@ class TrainingTaskEvaluation(BaseModel):
class TaskEvaluator: class TaskEvaluator:
def __init__(self, original_agent): def __init__(self, original_agent) -> None:
self.llm = original_agent.llm self.llm = original_agent.llm
self.original_agent = original_agent self.original_agent = original_agent

View File

@@ -10,7 +10,7 @@ class Logger(BaseModel):
_printer: Printer = PrivateAttr(default_factory=Printer) _printer: Printer = PrivateAttr(default_factory=Printer)
default_color: str = Field(default="bold_yellow") default_color: str = Field(default="bold_yellow")
def log(self, level, message, color=None): def log(self, level: str, message: str, color: str | None = None) -> None:
if color is None: if color is None:
color = self.default_color color = self.default_color
if self.verbose: if self.verbose:

View File

@@ -50,7 +50,7 @@ class RPMController(BaseModel):
else: else:
return _check_and_increment() return _check_and_increment()
def stop_rpm_counter(self): def stop_rpm_counter(self) -> None:
if self._timer: if self._timer:
self._timer.cancel() self._timer.cancel()
self._timer = None self._timer = None