mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-12 09:38:31 +00:00
fix: add type annotations and exclude tests from mypy
- Add type: ignore for mem0 import - Fix tool_usage.py cache_function None check - Change _execute_without_timeout return type to Any - Add type annotations to multiple functions: - add_sources() -> None - log() with proper parameter types - stop_rpm_counter() -> None - EventListener.__new__() -> Self - setup_listeners() -> None - Memory class __init__ methods -> None - TaskEvaluator.__init__() -> None - get_skipped_task_output() -> TaskOutput - Exclude tests directory from mypy checks in pyproject.toml - Update deprecated typing imports to use built-in types
This commit is contained in:
@@ -123,7 +123,7 @@ select = [
|
||||
|
||||
[tool.mypy]
|
||||
strict = true
|
||||
exclude = ["src/crewai/cli/templates"]
|
||||
exclude = ["src/crewai/cli/templates", "tests"]
|
||||
|
||||
[tool.bandit]
|
||||
exclude_dirs = ["src/crewai/cli/templates"]
|
||||
|
||||
@@ -530,7 +530,7 @@ class Agent(BaseAgent):
|
||||
future.cancel()
|
||||
raise RuntimeError(f"Task execution failed: {str(e)}")
|
||||
|
||||
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
|
||||
def _execute_without_timeout(self, task_prompt: str, task: Task) -> Any:
|
||||
"""Execute a task without a timeout.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,15 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from io import StringIO
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.constants import EMITTER_COLOR
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestResultEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
@@ -25,34 +41,21 @@ from crewai.events.types.llm_events import (
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailStartedEvent,
|
||||
LLMGuardrailCompletedEvent,
|
||||
)
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.logging_events import (
|
||||
AgentLogsStartedEvent,
|
||||
AgentLogsExecutionEvent,
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestResultEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.constants import EMITTER_COLOR
|
||||
|
||||
from .listeners.memory_listener import MemoryListener
|
||||
from .types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
@@ -61,32 +64,30 @@ from .types.flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from .types.task_events import TaskCompletedEvent, TaskFailedEvent, TaskStartedEvent
|
||||
from .types.tool_usage_events import (
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningStartedEvent,
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
)
|
||||
|
||||
from .listeners.memory_listener import MemoryListener
|
||||
|
||||
|
||||
class EventListener(BaseEventListener):
|
||||
_instance = None
|
||||
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
|
||||
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
|
||||
execution_spans: Dict[Task, Any] = Field(default_factory=dict)
|
||||
execution_spans: dict[Task, Any] = Field(default_factory=dict)
|
||||
next_chunk = 0
|
||||
text_stream = StringIO()
|
||||
knowledge_retrieval_in_progress = False
|
||||
knowledge_query_in_progress = False
|
||||
|
||||
def __new__(cls):
|
||||
def __new__(cls) -> Self:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
@@ -105,7 +106,7 @@ class EventListener(BaseEventListener):
|
||||
|
||||
# ----------- CREW EVENTS -----------
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
def setup_listeners(self, crewai_event_bus) -> None:
|
||||
@crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_started(source, event: CrewKickoffStartedEvent):
|
||||
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@@ -18,17 +18,17 @@ class Knowledge(BaseModel):
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
"""
|
||||
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
embedder: Optional[dict[str, Any]] = None
|
||||
collection_name: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
sources: List[BaseKnowledgeSource],
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: Optional[dict[str, Any]] = None,
|
||||
storage: Optional[KnowledgeStorage] = None,
|
||||
**data,
|
||||
):
|
||||
@@ -43,8 +43,8 @@ class Knowledge(BaseModel):
|
||||
self.storage.initialize_knowledge_storage()
|
||||
|
||||
def query(
|
||||
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> List[Dict[str, Any]]:
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Query across all knowledge sources to find the most relevant information.
|
||||
Returns the top_k most relevant chunks.
|
||||
@@ -62,7 +62,7 @@ class Knowledge(BaseModel):
|
||||
)
|
||||
return results
|
||||
|
||||
def add_sources(self):
|
||||
def add_sources(self) -> None:
|
||||
try:
|
||||
for source in self.sources:
|
||||
source.storage = self.storage
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import Any
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
|
||||
class EntityMemory(Memory):
|
||||
@@ -26,7 +26,9 @@ class EntityMemory(Memory):
|
||||
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(
|
||||
self, crew=None, embedder_config=None, storage=None, path=None
|
||||
) -> None:
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
from typing import Any, Dict, List
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class LongTermMemory(Memory):
|
||||
LongTermMemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, storage=None, path=None):
|
||||
def __init__(self, storage=None, path=None) -> None:
|
||||
if not storage:
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage=storage)
|
||||
@@ -84,7 +84,7 @@ class LongTermMemory(Memory):
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import Any, Dict, Optional
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
|
||||
class ShortTermMemory(Memory):
|
||||
@@ -28,7 +28,9 @@ class ShortTermMemory(Memory):
|
||||
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(
|
||||
self, crew=None, embedder_config=None, storage=None, path=None
|
||||
) -> None:
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
@@ -56,7 +58,7 @@ class ShortTermMemory(Memory):
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from mem0 import Memory, MemoryClient
|
||||
from mem0 import Memory, MemoryClient # type: ignore[import-not-found]
|
||||
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Callable
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -38,7 +39,7 @@ class ConditionalTask(Task):
|
||||
"""
|
||||
return self.condition(context)
|
||||
|
||||
def get_skipped_task_output(self):
|
||||
def get_skipped_task_output(self) -> TaskOutput:
|
||||
return TaskOutput(
|
||||
description=self.description,
|
||||
raw="",
|
||||
|
||||
@@ -287,7 +287,8 @@ class ToolUsage:
|
||||
if self.tools_handler:
|
||||
should_cache = True
|
||||
if (
|
||||
hasattr(available_tool, "cache_function")
|
||||
available_tool
|
||||
and hasattr(available_tool, "cache_function")
|
||||
and available_tool.cache_function
|
||||
):
|
||||
should_cache = available_tool.cache_function(
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities import Converter
|
||||
from crewai.events.types.task_events import TaskEvaluationEvent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.task_events import TaskEvaluationEvent
|
||||
from crewai.utilities import Converter
|
||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
||||
from crewai.utilities.training_converter import TrainingConverter
|
||||
|
||||
@@ -13,23 +11,23 @@ class Entity(BaseModel):
|
||||
name: str = Field(description="The name of the entity.")
|
||||
type: str = Field(description="The type of the entity.")
|
||||
description: str = Field(description="Description of the entity.")
|
||||
relationships: List[str] = Field(description="Relationships of the entity.")
|
||||
relationships: list[str] = Field(description="Relationships of the entity.")
|
||||
|
||||
|
||||
class TaskEvaluation(BaseModel):
|
||||
suggestions: List[str] = Field(
|
||||
suggestions: list[str] = Field(
|
||||
description="Suggestions to improve future similar tasks."
|
||||
)
|
||||
quality: float = Field(
|
||||
description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task."
|
||||
)
|
||||
entities: List[Entity] = Field(
|
||||
entities: list[Entity] = Field(
|
||||
description="Entities extracted from the task output."
|
||||
)
|
||||
|
||||
|
||||
class TrainingTaskEvaluation(BaseModel):
|
||||
suggestions: List[str] = Field(
|
||||
suggestions: list[str] = Field(
|
||||
description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions."
|
||||
)
|
||||
quality: float = Field(
|
||||
@@ -41,7 +39,7 @@ class TrainingTaskEvaluation(BaseModel):
|
||||
|
||||
|
||||
class TaskEvaluator:
|
||||
def __init__(self, original_agent):
|
||||
def __init__(self, original_agent) -> None:
|
||||
self.llm = original_agent.llm
|
||||
self.original_agent = original_agent
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ class Logger(BaseModel):
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
default_color: str = Field(default="bold_yellow")
|
||||
|
||||
def log(self, level, message, color=None):
|
||||
def log(self, level: str, message: str, color: str | None = None) -> None:
|
||||
if color is None:
|
||||
color = self.default_color
|
||||
if self.verbose:
|
||||
|
||||
@@ -50,7 +50,7 @@ class RPMController(BaseModel):
|
||||
else:
|
||||
return _check_and_increment()
|
||||
|
||||
def stop_rpm_counter(self):
|
||||
def stop_rpm_counter(self) -> None:
|
||||
if self._timer:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
|
||||
Reference in New Issue
Block a user