From 470af2f9e135aff4a38936e48ba07db29050f454 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Mon, 6 Apr 2026 20:32:19 +0800 Subject: [PATCH] feat: add llm_type and executor_type discriminators for checkpoint fidelity Add type discriminator fields to BaseLLM subclasses and BaseAgentExecutor subclasses so checkpoint deserialization restores the correct provider class instead of always creating LLM/CrewAgentExecutor. --- .../crewai/agents/agent_builder/base_agent.py | 34 ++++++++++++++++--- .../agent_builder/base_agent_executor.py | 1 + .../src/crewai/agents/crew_agent_executor.py | 1 + .../src/crewai/experimental/agent_executor.py | 1 + lib/crewai/src/crewai/llm.py | 1 + lib/crewai/src/crewai/llms/base_llm.py | 1 + .../llms/providers/anthropic/completion.py | 1 + .../crewai/llms/providers/azure/completion.py | 3 +- .../llms/providers/bedrock/completion.py | 3 +- .../llms/providers/gemini/completion.py | 1 + .../llms/providers/openai/completion.py | 2 ++ 11 files changed, 43 insertions(+), 6 deletions(-) diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index 9ea223b46..9fc71a541 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -64,11 +64,26 @@ def _serialize_crew_ref(value: Any) -> str | None: return str(value.id) if hasattr(value, "id") else str(value) +_LLM_TYPE_REGISTRY: dict[str, str] = { + "base": "crewai.llms.base_llm.BaseLLM", + "litellm": "crewai.llm.LLM", + "openai": "crewai.llms.providers.openai.completion.OpenAICompletion", + "anthropic": "crewai.llms.providers.anthropic.completion.AnthropicCompletion", + "azure": "crewai.llms.providers.azure.completion.AzureCompletion", + "bedrock": "crewai.llms.providers.bedrock.completion.BedrockCompletion", + "gemini": "crewai.llms.providers.gemini.completion.GeminiCompletion", +} + + def _validate_llm_ref(value: Any) -> Any: if isinstance(value, dict): - from crewai.llm import LLM + import importlib - return LLM(**value) + llm_type = value["llm_type"] + dotted = _LLM_TYPE_REGISTRY[llm_type] + mod_path, cls_name = dotted.rsplit(".", 1) + cls = getattr(importlib.import_module(mod_path), cls_name) + return cls(**value) return value @@ -80,11 +95,22 @@ def _resolve_agent(value: Any, info: Any) -> Any: return Agent.model_validate(value, context=getattr(info, "context", None)) +_EXECUTOR_TYPE_REGISTRY: dict[str, str] = { + "base": "crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor", + "crew": "crewai.agents.crew_agent_executor.CrewAgentExecutor", + "experimental": "crewai.experimental.agent_executor.AgentExecutor", +} + + def _validate_executor_ref(value: Any) -> Any: if isinstance(value, dict): - from crewai.agents.crew_agent_executor import CrewAgentExecutor + import importlib - return CrewAgentExecutor.model_validate(value) + executor_type = value["executor_type"] + dotted = _EXECUTOR_TYPE_REGISTRY[executor_type] + mod_path, cls_name = dotted.rsplit(".", 1) + cls = getattr(importlib.import_module(mod_path), cls_name) + return cls.model_validate(value) return value diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor.py index 37028a63b..ad56807e4 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: class BaseAgentExecutor(BaseModel): model_config = {"arbitrary_types_allowed": True} + executor_type: str = "base" crew: Crew | None = Field(default=None, exclude=True) agent: BaseAgent | None = Field(default=None, exclude=True) task: Task | None = Field(default=None, exclude=True) diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py index 83b9b6de3..0a002ed8e 100644 --- a/lib/crewai/src/crewai/agents/crew_agent_executor.py +++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py @@ -96,6 +96,7 @@ class CrewAgentExecutor(BaseAgentExecutor): LLM interactions, tool execution, and feedback handling. """ + executor_type: Literal["crew"] = "crew" llm: Annotated[ BaseLLM | str | None, BeforeValidator(_validate_llm_ref), diff --git a/lib/crewai/src/crewai/experimental/agent_executor.py b/lib/crewai/src/crewai/experimental/agent_executor.py index 20584877a..067489c8e 100644 --- a/lib/crewai/src/crewai/experimental/agent_executor.py +++ b/lib/crewai/src/crewai/experimental/agent_executor.py @@ -170,6 +170,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor _skip_auto_memory: bool = True + executor_type: Literal["experimental"] = "experimental" suppress_flow_events: bool = True # always suppress for executor llm: BaseLLM = Field(exclude=True) prompt: SystemPromptResult | StandardPromptResult = Field(exclude=True) diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 57079f63e..192fffd1a 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -343,6 +343,7 @@ class AccumulatedToolArgs(BaseModel): class LLM(BaseLLM): + llm_type: Literal["litellm"] = "litellm" completion_cost: float | None = None timeout: float | int | None = None top_p: float | None = None diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llms/base_llm.py index 9f00d1db8..fd3c8c45e 100644 --- a/lib/crewai/src/crewai/llms/base_llm.py +++ b/lib/crewai/src/crewai/llms/base_llm.py @@ -117,6 +117,7 @@ class BaseLLM(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True) + llm_type: str = "base" model: str temperature: float | None = None api_key: str | None = None diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py index d710404bd..b6df34b94 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py @@ -148,6 +148,7 @@ class AnthropicCompletion(BaseLLM): offering native tool use, streaming support, and proper message formatting. """ + llm_type: Literal["anthropic"] = "anthropic" model: str = "claude-3-5-sonnet-20241022" timeout: float | None = None max_retries: int = 2 diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py index 52bf05531..db7ab7e73 100644 --- a/lib/crewai/src/crewai/llms/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import logging import os -from typing import Any, TypedDict +from typing import Any, Literal, TypedDict from urllib.parse import urlparse from pydantic import BaseModel, PrivateAttr, model_validator @@ -74,6 +74,7 @@ class AzureCompletion(BaseLLM): offering native function calling, streaming support, and proper Azure authentication. """ + llm_type: Literal["azure"] = "azure" endpoint: str | None = None api_version: str | None = None timeout: float | None = None diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py index 6fcf3581d..c25c9bfec 100644 --- a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py @@ -5,7 +5,7 @@ from contextlib import AsyncExitStack import json import logging import os -from typing import TYPE_CHECKING, Any, TypedDict, cast +from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast from pydantic import BaseModel, PrivateAttr, model_validator from typing_extensions import Required @@ -228,6 +228,7 @@ class BedrockCompletion(BaseLLM): - Model-specific conversation format handling (e.g., Cohere requirements) """ + llm_type: Literal["bedrock"] = "bedrock" model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0" aws_access_key_id: str | None = None aws_secret_access_key: str | None = None diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py index f790e22cf..c84f7f5fd 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py @@ -41,6 +41,7 @@ class GeminiCompletion(BaseLLM): offering native function calling, streaming support, and proper Gemini formatting. """ + llm_type: Literal["gemini"] = "gemini" model: str = "gemini-2.0-flash-001" project: str | None = None location: str | None = None diff --git a/lib/crewai/src/crewai/llms/providers/openai/completion.py b/lib/crewai/src/crewai/llms/providers/openai/completion.py index 89edf7ab3..ee84467a6 100644 --- a/lib/crewai/src/crewai/llms/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llms/providers/openai/completion.py @@ -180,6 +180,8 @@ class OpenAICompletion(BaseLLM): chain-of-thought without storing data on OpenAI servers. """ + llm_type: Literal["openai"] = "openai" + BUILTIN_TOOL_TYPES: ClassVar[dict[str, str]] = { "web_search": "web_search_preview", "file_search": "file_search",