refactor: narrow Any-typed fields to concrete types across core models

This commit is contained in:
Greyson Lalonde
2026-03-07 16:52:55 -05:00
parent 3dc3f8bb52
commit a10ef6e28d
7 changed files with 123 additions and 106 deletions

View File

@@ -35,6 +35,7 @@ from typing_extensions import Self
if TYPE_CHECKING:
from crewai_files import FileInput
from opentelemetry.trace import Span
try:
from crewai_files import get_supported_content_types
@@ -65,8 +66,10 @@ from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
has_user_declined_tracing,
set_tracing_enabled,
should_enable_tracing,
should_suppress_tracing_messages,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
@@ -83,7 +86,10 @@ from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.memory.memory_scope import MemoryScope, MemorySlice
from crewai.memory.unified_memory import Memory
from crewai.process import Process
from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.rag.types import SearchResult
from crewai.security.fingerprint import Fingerprint
@@ -94,6 +100,8 @@ from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.agent_tools.read_file_tool import ReadFileTool
from crewai.tools.base_tool import BaseTool
from crewai.tools.memory_tools import create_memory_tools
from crewai.types.callable import SerializableCallable
from crewai.types.streaming import CrewStreamingOutput
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
@@ -165,12 +173,12 @@ class Crew(FlowTrackable, BaseModel):
"""
__hash__ = object.__hash__
_execution_span: Any = PrivateAttr()
_execution_span: Span | None = PrivateAttr(default=None)
_rpm_controller: RPMController = PrivateAttr()
_logger: Logger = PrivateAttr()
_file_handler: FileHandler = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler)
_memory: Any = PrivateAttr(default=None) # Unified Memory | MemoryScope
_memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None)
_train: bool | None = PrivateAttr(default=False)
_train_iteration: int | None = PrivateAttr()
_inputs: dict[str, Any] | None = PrivateAttr(default=None)
@@ -188,7 +196,7 @@ class Crew(FlowTrackable, BaseModel):
agents: list[BaseAgent] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
verbose: bool = Field(default=False)
memory: bool | Any = Field(
memory: bool | Memory | MemoryScope | MemorySlice = Field(
default=False,
description=(
"Enable crew memory. Pass True for default Memory(), "
@@ -203,23 +211,23 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
manager_llm: str | InstanceOf[BaseLLM] | None = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: BaseAgent | None = Field(
description="Custom agent that will be used as manager.", default=None
)
function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
description="Language model that will run the agent.", default=None
)
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
share_crew: bool | None = Field(default=False)
step_callback: Any | None = Field(
step_callback: SerializableCallable | None = Field(
default=None,
description="Callback to be executed after each step for all agents execution.",
)
task_callback: Any | None = Field(
task_callback: SerializableCallable | None = Field(
default=None,
description="Callback to be executed after each task for all agents execution.",
)
@@ -262,7 +270,7 @@ class Crew(FlowTrackable, BaseModel):
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
planning_llm: str | InstanceOf[BaseLLM] | None = Field(
default=None,
description=(
"Language model that will run the AgentPlanner if planning is True."
@@ -283,7 +291,7 @@ class Crew(FlowTrackable, BaseModel):
"knowledge object."
),
)
chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
chat_llm: str | InstanceOf[BaseLLM] | None = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
@@ -356,12 +364,8 @@ class Crew(FlowTrackable, BaseModel):
def create_crew_memory(self) -> Crew:
"""Initialize unified memory, respecting crew embedder config."""
if self.memory is True:
from crewai.memory.unified_memory import Memory
embedder = None
if self.embedder is not None:
from crewai.rag.embeddings.factory import build_embedder
embedder = build_embedder(self.embedder)
self._memory = Memory(embedder=embedder)
elif self.memory:
@@ -1411,7 +1415,7 @@ class Crew(FlowTrackable, BaseModel):
return tools
def _add_memory_tools(
self, tools: list[BaseTool], memory: Any
self, tools: list[BaseTool], memory: Memory | MemoryScope | MemorySlice
) -> list[BaseTool]:
"""Add recall and remember tools when memory is available.
@@ -1422,8 +1426,6 @@ class Crew(FlowTrackable, BaseModel):
Returns:
Updated list with memory tools added.
"""
from crewai.tools.memory_tools import create_memory_tools
return self._merge_tools(tools, create_memory_tools(memory))
def _add_file_tools(
@@ -2006,11 +2008,6 @@ class Crew(FlowTrackable, BaseModel):
@staticmethod
def _show_tracing_disabled_message() -> None:
"""Show a message when tracing is disabled."""
from crewai.events.listeners.tracing.utils import (
has_user_declined_tracing,
should_suppress_tracing_messages,
)
if should_suppress_tracing_messages():
return

View File

@@ -17,9 +17,12 @@ from collections.abc import (
ValuesView,
)
from concurrent.futures import Future, ThreadPoolExecutor
import contextvars
import copy
from datetime import datetime
import enum
import inspect
import json
import logging
import threading
from typing import (
@@ -49,6 +52,7 @@ from crewai.events.event_context import (
reset_last_event_id,
triggered_by_scope,
)
from crewai.events.event_listener import event_listener
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
@@ -61,16 +65,27 @@ from crewai.events.listeners.tracing.utils import (
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowFinishedEvent,
FlowInputReceivedEvent,
FlowInputRequestedEvent,
FlowPausedEvent,
FlowPlotEvent,
FlowStartedEvent,
HumanFeedbackReceivedEvent,
HumanFeedbackRequestedEvent,
MethodExecutionFailedEvent,
MethodExecutionFinishedEvent,
MethodExecutionPausedEvent,
MethodExecutionStartedEvent,
)
from crewai.flow.async_feedback.providers import ConsoleProvider
from crewai.flow.async_feedback.types import HumanFeedbackPending
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.flow_context import current_flow_id, current_flow_request_id
from crewai.flow.flow_config import flow_config
from crewai.flow.flow_context import (
current_flow_id,
current_flow_method_name,
current_flow_request_id,
)
from crewai.flow.flow_wrappers import (
FlowCondition,
FlowConditions,
@@ -80,6 +95,9 @@ from crewai.flow.flow_wrappers import (
SimpleFlowCondition,
StartMethod,
)
from crewai.flow.human_feedback import HumanFeedbackResult
from crewai.flow.input_provider import InputResponse
from crewai.flow.persistence import SQLiteFlowPersistence
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import (
FlowExecutionData,
@@ -98,14 +116,18 @@ from crewai.flow.utils import (
is_flow_method_name,
is_simple_flow_condition,
)
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.i18n import get_i18n
if TYPE_CHECKING:
from crewai_files import FileInput
from crewai.flow.async_feedback.types import PendingFeedbackContext
from crewai.flow.human_feedback import HumanFeedbackResult
from crewai.llms.base_llm import BaseLLM
from crewai.flow.input_provider import InputProvider
from crewai.memory.memory_scope import MemoryScope, MemorySlice
from crewai.memory.unified_memory import Memory
from crewai.flow.visualization import build_flow_structure, render_interactive
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
@@ -753,10 +775,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
name: str | None = None
tracing: bool | None = None
stream: bool = False
memory: Any = (
None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
)
input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
memory: Memory | MemoryScope | MemorySlice | None = None
input_provider: InputProvider | None = None
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
class _FlowGeneric(cls): # type: ignore
@@ -885,8 +905,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
"""
if self.memory is None:
raise ValueError("No memory configured for this flow")
if isinstance(content, list):
from crewai.memory.unified_memory import Memory
if isinstance(content, list) and isinstance(self.memory, Memory):
return self.memory.remember_many(content, **kwargs)
if isinstance(content, list):
return [self.memory.remember(c, **kwargs) for c in content]
return self.memory.remember(content, **kwargs)
def extract_memories(self, content: str) -> list[str]:
@@ -1115,8 +1140,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
```
"""
if persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
persistence = SQLiteFlowPersistence()
# Load pending feedback context and state
@@ -1229,10 +1252,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
Raises:
ValueError: If no pending feedback context exists
"""
from datetime import datetime
from crewai.flow.human_feedback import HumanFeedbackResult
if self._pending_feedback_context is None:
raise ValueError(
"No pending feedback context. Use from_pending() to restore a paused flow."
@@ -1315,13 +1334,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
except Exception as e:
# Check if flow was paused again for human feedback (loop case)
from crewai.flow.async_feedback.types import HumanFeedbackPending
if isinstance(e, HumanFeedbackPending):
# Auto-save pending feedback (create default persistence if needed)
if self._persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
self._persistence = SQLiteFlowPersistence()
state_data = (
@@ -1724,8 +1739,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
result_holder.append(result)
except Exception as e:
# HumanFeedbackPending is expected control flow, not an error
from crewai.flow.async_feedback.types import HumanFeedbackPending
if isinstance(e, HumanFeedbackPending):
result_holder.append(e)
else:
@@ -1794,8 +1807,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
result_holder.append(result)
except Exception as e:
# HumanFeedbackPending is expected control flow, not an error
from crewai.flow.async_feedback.types import HumanFeedbackPending
if isinstance(e, HumanFeedbackPending):
result_holder.append(e)
else:
@@ -1920,13 +1931,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
await asyncio.gather(*tasks)
except Exception as e:
# Check if flow was paused for human feedback
from crewai.flow.async_feedback.types import HumanFeedbackPending
if isinstance(e, HumanFeedbackPending):
# Auto-save pending feedback (create default persistence if needed)
if self._persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
self._persistence = SQLiteFlowPersistence()
state_data = (
@@ -2162,8 +2169,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Set method name in context so ask() can read it without
# stack inspection. Must happen before copy_context() so the
# value propagates into the thread pool for sync methods.
from crewai.flow.flow_context import current_flow_method_name
method_name_token = current_flow_method_name.set(method_name)
try:
if asyncio.iscoroutinefunction(method):
@@ -2171,8 +2176,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
else:
# Run sync methods in thread pool for isolation
# This allows Agent.kickoff() to work synchronously inside Flow methods
import contextvars
ctx = contextvars.copy_context()
result = await asyncio.to_thread(ctx.run, method, *args, **kwargs)
finally:
@@ -2206,15 +2209,11 @@ class Flow(Generic[T], metaclass=FlowMeta):
return result, finished_event_id
except Exception as e:
# Check if this is a HumanFeedbackPending exception (paused, not failed)
from crewai.flow.async_feedback.types import HumanFeedbackPending
if isinstance(e, HumanFeedbackPending):
e.context.method_name = method_name
# Auto-save pending feedback (create default persistence if needed)
if self._persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
self._persistence = SQLiteFlowPersistence()
# Emit paused event (not failed)
@@ -2646,8 +2645,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
except Exception as e:
# Don't log HumanFeedbackPending as an error - it's expected control flow
from crewai.flow.async_feedback.types import HumanFeedbackPending
if not isinstance(e, HumanFeedbackPending):
logger.error(f"Error executing listener {listener_name}: {e}")
raise
@@ -2665,9 +2662,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
Returns:
An object implementing the ``InputProvider`` protocol.
"""
from crewai.flow.async_feedback.providers import ConsoleProvider
from crewai.flow.flow_config import flow_config
if self.input_provider is not None:
return self.input_provider
if flow_config.input_provider is not None:
@@ -2753,19 +2747,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
return topic
```
"""
from concurrent.futures import (
ThreadPoolExecutor,
TimeoutError as FuturesTimeoutError,
)
from datetime import datetime
from crewai.events.types.flow_events import (
FlowInputReceivedEvent,
FlowInputRequestedEvent,
)
from crewai.flow.flow_context import current_flow_method_name
from crewai.flow.input_provider import InputResponse
method_name = current_flow_method_name.get("unknown")
# Emit input requested event
@@ -2796,7 +2777,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
try:
raw = future.result(timeout=timeout)
except FuturesTimeoutError:
except TimeoutError:
future.cancel()
raw = None
finally:
@@ -2869,12 +2850,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
Returns:
The human's feedback as a string. Empty string if no feedback provided.
"""
from crewai.events.event_listener import event_listener
from crewai.events.types.flow_events import (
HumanFeedbackReceivedEvent,
HumanFeedbackRequestedEvent,
)
# Emit feedback requested event
crewai_event_bus.emit(
self,
@@ -2948,18 +2923,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
Returns:
One of the outcome strings that best matches the feedback intent.
"""
from typing import Literal
from pydantic import BaseModel, Field
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM as BaseLLMClass
from crewai.utilities.i18n import get_i18n
llm_instance: BaseLLMClass
llm_instance: BaseLLM
if isinstance(llm, str):
llm_instance = LLM(model=llm)
elif isinstance(llm, BaseLLMClass):
elif isinstance(llm, BaseLLM):
llm_instance = llm
else:
raise ValueError(f"Invalid llm type: {type(llm)}. Expected str or BaseLLM.")
@@ -2992,8 +2959,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
if isinstance(response, str):
import json
try:
parsed = json.loads(response)
return str(parsed.get("outcome", outcomes[0]))
@@ -3058,8 +3023,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
This method uses the centralized Rich console formatter for output
and the standard logging module for log level support.
"""
from crewai.events.event_listener import event_listener
event_listener.formatter.console.print(message, style=color)
if level == "info":
logger.info(message)

View File

@@ -83,6 +83,7 @@ if TYPE_CHECKING:
VoyageAIEmbeddingFunction,
)
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
from crewai.rag.embeddings.types import EmbedderConfig
T = TypeVar("T", bound=EmbeddingFunction[Any])
@@ -349,6 +350,10 @@ def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
def build_embedder(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ...
@overload
def build_embedder(spec: EmbedderConfig) -> EmbeddingFunction[Any]: ...
def build_embedder(spec): # type: ignore[no-untyped-def]
"""Build an embedding function from either a provider spec or a provider instance.

View File

@@ -44,6 +44,7 @@ from crewai.security import Fingerprint, SecurityConfig
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
from crewai.tools.base_tool import BaseTool
from crewai.types.callable import SerializableCallable
from crewai.utilities.config import process_config
from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
from crewai.utilities.converter import Converter, convert_to_model
@@ -123,8 +124,9 @@ class Task(BaseModel):
description="Configuration for the agent",
default=None,
)
callback: Any | None = Field(
description="Callback to be executed after the task is completed.", default=None
callback: SerializableCallable | None = Field(
default=None,
description="Callback to be executed after the task is completed.",
)
agent: BaseAgent | None = Field(
description="Agent responsible for execution the task.", default=None

View File

@@ -0,0 +1,50 @@
"""Serializable callable type for Pydantic models."""
from __future__ import annotations
from collections.abc import Callable
import importlib
from typing import Annotated, Any
from pydantic import BeforeValidator, PlainSerializer, WithJsonSchema
def _deserialize_callable(v: str | Callable[..., Any]) -> Callable[..., Any]:
"""Deserialize a dotted import path to a callable, or pass through if already callable."""
if isinstance(v, str):
module_path, _, name = v.rpartition(".")
if not module_path:
raise ValueError(f"Invalid callable path: {v!r} (expected 'module.name')")
module = importlib.import_module(module_path)
obj: Callable[..., Any] = getattr(module, name)
if not callable(obj):
raise ValueError(f"{v!r} resolved to {type(obj).__name__}, not a callable")
return obj
return v
def _serialize_callable(v: Callable[..., Any]) -> str:
"""Serialize a callable to its dotted import path."""
module = getattr(v, "__module__", None)
qualname = getattr(v, "__qualname__", None)
name = getattr(v, "__name__", None)
if not module or not name:
raise ValueError(
f"Cannot serialize {v!r}: missing __module__ or __name__. "
"Only top-level named functions are serializable."
)
if qualname and "<" in qualname:
raise ValueError(
f"Cannot serialize {v!r}: lambdas and nested functions are not serializable. "
"Use a top-level named function instead."
)
return f"{module}.{qualname or name}"
SerializableCallable = Annotated[
Callable[..., Any],
BeforeValidator(_deserialize_callable),
PlainSerializer(_serialize_callable, return_type=str),
WithJsonSchema({"type": "string"}),
]

View File

@@ -897,7 +897,7 @@ class TestCollapseToOutcomeJsonParsing:
"""Test that JSON string response from LLM is correctly parsed."""
flow = Flow()
with patch("crewai.llm.LLM") as MockLLM:
with patch("crewai.flow.flow.LLM") as MockLLM:
mock_llm = MagicMock()
# Simulate LLM returning JSON string (the bug we fixed)
mock_llm.call.return_value = '{"outcome": "approved"}'
@@ -915,7 +915,7 @@ class TestCollapseToOutcomeJsonParsing:
"""Test that plain string response is correctly matched."""
flow = Flow()
with patch("crewai.llm.LLM") as MockLLM:
with patch("crewai.flow.flow.LLM") as MockLLM:
mock_llm = MagicMock()
# Simulate LLM returning plain outcome string
mock_llm.call.return_value = "rejected"
@@ -933,7 +933,7 @@ class TestCollapseToOutcomeJsonParsing:
"""Test that invalid JSON falls back to string matching."""
flow = Flow()
with patch("crewai.llm.LLM") as MockLLM:
with patch("crewai.flow.flow.LLM") as MockLLM:
mock_llm = MagicMock()
# Invalid JSON that contains "approved"
mock_llm.call.return_value = "{invalid json but says approved"
@@ -951,7 +951,7 @@ class TestCollapseToOutcomeJsonParsing:
"""Test that LLM exception triggers fallback to simple prompting."""
flow = Flow()
with patch("crewai.llm.LLM") as MockLLM:
with patch("crewai.flow.flow.LLM") as MockLLM:
mock_llm = MagicMock()
# First call raises, second call succeeds (fallback)
mock_llm.call.side_effect = [

View File

@@ -36,7 +36,7 @@ from crewai.flow import Flow, start
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.llm import LLM
from crewai.memory.unified_memory import Memory
from crewai.process import Process
from crewai.project import CrewBase, agent, before_kickoff, crew, task
from crewai.task import Task
@@ -2618,9 +2618,9 @@ def test_memory_remember_called_after_task():
)
with patch.object(
crew._memory, "extract_memories", wraps=crew._memory.extract_memories
Memory, "extract_memories", wraps=crew._memory.extract_memories
) as extract_mock, patch.object(
crew._memory, "remember", wraps=crew._memory.remember
Memory, "remember", wraps=crew._memory.remember
) as remember_mock:
crew.kickoff()
@@ -4773,13 +4773,13 @@ def test_memory_remember_receives_task_content():
# Mock extract_memories to return fake memories and capture the raw input.
# No wraps= needed -- the test only checks what args it receives, not the output.
patch.object(
crew._memory, "extract_memories", return_value=["Fake memory."]
Memory, "extract_memories", return_value=["Fake memory."]
) as extract_mock,
# Mock recall to avoid LLM calls for query analysis (not in cassette).
patch.object(crew._memory, "recall", return_value=[]),
patch.object(Memory, "recall", return_value=[]),
# Mock remember_many to prevent the background save from triggering
# LLM calls (field resolution) that aren't in the cassette.
patch.object(crew._memory, "remember_many", return_value=[]),
patch.object(Memory, "remember_many", return_value=[]),
):
crew.kickoff()