fix: add type annotations and exclude tests from mypy

- Add type: ignore for mem0 import
- Fix tool_usage.py cache_function None check
- Change _execute_without_timeout return type to Any
- Add type annotations to multiple functions:
  - add_sources() -> None
  - log() with proper parameter types
  - stop_rpm_counter() -> None
  - EventListener.__new__() -> Self
  - setup_listeners() -> None
  - Memory class __init__ methods -> None
  - TaskEvaluator.__init__() -> None
  - get_skipped_task_output() -> TaskOutput
- Exclude tests directory from mypy checks in pyproject.toml
- Update deprecated typing imports to use built-in types
This commit is contained in:
Greyson LaLonde
2025-09-04 11:11:59 -04:00
parent 0bab041531
commit 2ba48dd82a
13 changed files with 99 additions and 94 deletions

View File

@@ -123,7 +123,7 @@ select = [
[tool.mypy]
strict = true
exclude = ["src/crewai/cli/templates"]
exclude = ["src/crewai/cli/templates", "tests"]
[tool.bandit]
exclude_dirs = ["src/crewai/cli/templates"]

View File

@@ -530,7 +530,7 @@ class Agent(BaseAgent):
future.cancel()
raise RuntimeError(f"Task execution failed: {str(e)}")
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
def _execute_without_timeout(self, task_prompt: str, task: Task) -> Any:
"""Execute a task without a timeout.
Args:

View File

@@ -1,15 +1,31 @@
from __future__ import annotations
from io import StringIO
from typing import Any, Dict
from typing import Any
from pydantic import Field, PrivateAttr
from crewai.llm import LLM
from crewai.task import Task
from crewai.telemetry.telemetry import Telemetry
from crewai.utilities import Logger
from crewai.utilities.constants import EMITTER_COLOR
from typing_extensions import Self
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestResultEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
@@ -25,34 +41,21 @@ from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
)
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
LLMGuardrailStartedEvent,
)
from crewai.events.types.logging_events import (
AgentLogsStartedEvent,
AgentLogsExecutionEvent,
AgentLogsStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestResultEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.llm import LLM
from crewai.task import Task
from crewai.telemetry.telemetry import Telemetry
from crewai.utilities import Logger
from crewai.utilities.constants import EMITTER_COLOR
from .listeners.memory_listener import MemoryListener
from .types.flow_events import (
FlowCreatedEvent,
FlowFinishedEvent,
@@ -61,32 +64,30 @@ from .types.flow_events import (
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from .types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
)
from .types.task_events import TaskCompletedEvent, TaskFailedEvent, TaskStartedEvent
from .types.tool_usage_events import (
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from .types.reasoning_events import (
AgentReasoningStartedEvent,
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
)
from .listeners.memory_listener import MemoryListener
class EventListener(BaseEventListener):
_instance = None
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
execution_spans: Dict[Task, Any] = Field(default_factory=dict)
execution_spans: dict[Task, Any] = Field(default_factory=dict)
next_chunk = 0
text_stream = StringIO()
knowledge_retrieval_in_progress = False
knowledge_query_in_progress = False
def __new__(cls):
def __new__(cls) -> Self:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
@@ -105,7 +106,7 @@ class EventListener(BaseEventListener):
# ----------- CREW EVENTS -----------
def setup_listeners(self, crewai_event_bus):
def setup_listeners(self, crewai_event_bus) -> None:
@crewai_event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source, event: CrewKickoffStartedEvent):
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field
@@ -18,17 +18,17 @@ class Knowledge(BaseModel):
embedder: Optional[Dict[str, Any]] = None
"""
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
embedder: Optional[dict[str, Any]] = None
collection_name: Optional[str] = None
def __init__(
self,
collection_name: str,
sources: List[BaseKnowledgeSource],
embedder: Optional[Dict[str, Any]] = None,
sources: list[BaseKnowledgeSource],
embedder: Optional[dict[str, Any]] = None,
storage: Optional[KnowledgeStorage] = None,
**data,
):
@@ -43,8 +43,8 @@ class Knowledge(BaseModel):
self.storage.initialize_knowledge_storage()
def query(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> List[Dict[str, Any]]:
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[dict[str, Any]]:
"""
Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.
@@ -62,7 +62,7 @@ class Knowledge(BaseModel):
)
return results
def add_sources(self):
def add_sources(self) -> None:
try:
for source in self.sources:
source.storage = self.storage

View File

@@ -1,20 +1,20 @@
from typing import Any
import time
from typing import Any
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class EntityMemory(Memory):
@@ -26,7 +26,9 @@ class EntityMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self, crew=None, embedder_config=None, storage=None, path=None
) -> None:
memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0":
try:

View File

@@ -1,17 +1,17 @@
from typing import Any, Dict, List
import time
from typing import Any
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
@@ -24,7 +24,7 @@ class LongTermMemory(Memory):
LongTermMemoryItem instances.
"""
def __init__(self, storage=None, path=None):
def __init__(self, storage=None, path=None) -> None:
if not storage:
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage)
@@ -84,7 +84,7 @@ class LongTermMemory(Memory):
self,
task: str,
latest_n: int = 3,
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(

View File

@@ -1,20 +1,20 @@
from typing import Any, Dict, Optional
import time
from typing import Any, Optional
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class ShortTermMemory(Memory):
@@ -28,7 +28,9 @@ class ShortTermMemory(Memory):
_memory_provider: Optional[str] = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self, crew=None, embedder_config=None, storage=None, path=None
) -> None:
memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0":
try:
@@ -56,7 +58,7 @@ class ShortTermMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
) -> None:
crewai_event_bus.emit(
self,

View File

@@ -2,7 +2,7 @@ import os
from collections import defaultdict
from typing import Any
from mem0 import Memory, MemoryClient
from mem0 import Memory, MemoryClient # type: ignore[import-not-found]
from crewai.memory.storage.interface import Storage
from crewai.utilities.chromadb import sanitize_collection_name

View File

@@ -1,4 +1,5 @@
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
from pydantic import Field
@@ -38,7 +39,7 @@ class ConditionalTask(Task):
"""
return self.condition(context)
def get_skipped_task_output(self):
def get_skipped_task_output(self) -> TaskOutput:
return TaskOutput(
description=self.description,
raw="",

View File

@@ -287,7 +287,8 @@ class ToolUsage:
if self.tools_handler:
should_cache = True
if (
hasattr(available_tool, "cache_function")
available_tool
and hasattr(available_tool, "cache_function")
and available_tool.cache_function
):
should_cache = available_tool.cache_function(

View File

@@ -1,10 +1,8 @@
from typing import List
from pydantic import BaseModel, Field
from crewai.utilities import Converter
from crewai.events.types.task_events import TaskEvaluationEvent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.task_events import TaskEvaluationEvent
from crewai.utilities import Converter
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
from crewai.utilities.training_converter import TrainingConverter
@@ -13,23 +11,23 @@ class Entity(BaseModel):
name: str = Field(description="The name of the entity.")
type: str = Field(description="The type of the entity.")
description: str = Field(description="Description of the entity.")
relationships: List[str] = Field(description="Relationships of the entity.")
relationships: list[str] = Field(description="Relationships of the entity.")
class TaskEvaluation(BaseModel):
suggestions: List[str] = Field(
suggestions: list[str] = Field(
description="Suggestions to improve future similar tasks."
)
quality: float = Field(
description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task."
)
entities: List[Entity] = Field(
entities: list[Entity] = Field(
description="Entities extracted from the task output."
)
class TrainingTaskEvaluation(BaseModel):
suggestions: List[str] = Field(
suggestions: list[str] = Field(
description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions."
)
quality: float = Field(
@@ -41,7 +39,7 @@ class TrainingTaskEvaluation(BaseModel):
class TaskEvaluator:
def __init__(self, original_agent):
def __init__(self, original_agent) -> None:
self.llm = original_agent.llm
self.original_agent = original_agent

View File

@@ -10,7 +10,7 @@ class Logger(BaseModel):
_printer: Printer = PrivateAttr(default_factory=Printer)
default_color: str = Field(default="bold_yellow")
def log(self, level, message, color=None):
def log(self, level: str, message: str, color: str | None = None) -> None:
if color is None:
color = self.default_color
if self.verbose:

View File

@@ -50,7 +50,7 @@ class RPMController(BaseModel):
else:
return _check_and_increment()
def stop_rpm_counter(self):
def stop_rpm_counter(self) -> None:
if self._timer:
self._timer.cancel()
self._timer = None