feat: add llm_type and executor_type discriminators for checkpoint fidelity

Add type discriminator fields to BaseLLM subclasses and
BaseAgentExecutor subclasses so checkpoint deserialization restores
the correct provider class instead of always creating LLM/CrewAgentExecutor.
This commit is contained in:
Greyson LaLonde
2026-04-06 20:32:19 +08:00
parent 97866c68ce
commit 470af2f9e1
11 changed files with 43 additions and 6 deletions

View File

@@ -64,11 +64,26 @@ def _serialize_crew_ref(value: Any) -> str | None:
return str(value.id) if hasattr(value, "id") else str(value)
# Maps each ``llm_type`` discriminator value to the dotted import path of the
# class to rebuild for it. Paths are kept as strings and resolved on demand,
# so listing a provider here does not import its module.
_LLM_TYPE_REGISTRY: dict[str, str] = {
    "base": "crewai.llms.base_llm.BaseLLM",
    "litellm": "crewai.llm.LLM",
    "openai": "crewai.llms.providers.openai.completion.OpenAICompletion",
    "anthropic": "crewai.llms.providers.anthropic.completion.AnthropicCompletion",
    "azure": "crewai.llms.providers.azure.completion.AzureCompletion",
    "bedrock": "crewai.llms.providers.bedrock.completion.BedrockCompletion",
    "gemini": "crewai.llms.providers.gemini.completion.GeminiCompletion",
}


def _validate_llm_ref(value: Any) -> Any:
    """Rebuild an LLM instance from its serialized dict form.

    Non-dict values (already-built instances, model-name strings, ``None``)
    are returned unchanged. A dict is dispatched on its ``"llm_type"`` key to
    the class registered in ``_LLM_TYPE_REGISTRY`` and constructed with the
    dict's items as keyword arguments.

    Args:
        value: The raw field value handed to the validator.

    Returns:
        ``value`` itself when it is not a dict, otherwise a new instance of
        the class registered for its ``llm_type``.

    Raises:
        ValueError: If ``llm_type`` is present but is not a registered
            discriminator.
    """
    if not isinstance(value, dict):
        return value

    import importlib

    # Dicts written before the discriminator existed carry no "llm_type".
    # They used to be rebuilt as crewai.llm.LLM, which is the "litellm" entry,
    # so keep that behavior instead of failing with KeyError.
    llm_type = value.get("llm_type", "litellm")

    # isinstance guard: an unhashable discriminator would otherwise raise
    # TypeError from the dict lookup rather than a clean validation error.
    dotted = _LLM_TYPE_REGISTRY.get(llm_type) if isinstance(llm_type, str) else None
    if dotted is None:
        known = ", ".join(sorted(_LLM_TYPE_REGISTRY))
        # ValueError (not KeyError) so a validator framework that treats
        # ValueError as a validation failure can report it as one.
        raise ValueError(f"Unknown llm_type {llm_type!r}; expected one of: {known}")

    mod_path, cls_name = dotted.rsplit(".", 1)
    cls = getattr(importlib.import_module(mod_path), cls_name)
    return cls(**value)
@@ -80,11 +95,22 @@ def _resolve_agent(value: Any, info: Any) -> Any:
return Agent.model_validate(value, context=getattr(info, "context", None))
# Maps each ``executor_type`` discriminator value to the dotted import path of
# the executor class to rebuild for it. Paths are kept as strings and resolved
# on demand, so listing an executor here does not import its module.
_EXECUTOR_TYPE_REGISTRY: dict[str, str] = {
    "base": "crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor",
    "crew": "crewai.agents.crew_agent_executor.CrewAgentExecutor",
    "experimental": "crewai.experimental.agent_executor.AgentExecutor",
}


def _validate_executor_ref(value: Any) -> Any:
    """Rebuild an agent executor from its serialized dict form.

    Non-dict values are returned unchanged. A dict is dispatched on its
    ``"executor_type"`` key to the class registered in
    ``_EXECUTOR_TYPE_REGISTRY`` and passed to that class's ``model_validate``.

    Args:
        value: The raw field value handed to the validator.

    Returns:
        ``value`` itself when it is not a dict, otherwise the result of
        ``model_validate`` on the class registered for its ``executor_type``.

    Raises:
        ValueError: If ``executor_type`` is present but is not a registered
            discriminator.
    """
    if not isinstance(value, dict):
        return value

    import importlib

    # Dicts written before the discriminator existed carry no "executor_type".
    # They used to be rebuilt as CrewAgentExecutor, which is the "crew" entry,
    # so keep that behavior instead of failing with KeyError.
    executor_type = value.get("executor_type", "crew")

    # isinstance guard: an unhashable discriminator would otherwise raise
    # TypeError from the dict lookup rather than a clean validation error.
    dotted = (
        _EXECUTOR_TYPE_REGISTRY.get(executor_type)
        if isinstance(executor_type, str)
        else None
    )
    if dotted is None:
        known = ", ".join(sorted(_EXECUTOR_TYPE_REGISTRY))
        # ValueError (not KeyError) so a validator framework that treats
        # ValueError as a validation failure can report it as one.
        raise ValueError(
            f"Unknown executor_type {executor_type!r}; expected one of: {known}"
        )

    mod_path, cls_name = dotted.rsplit(".", 1)
    cls = getattr(importlib.import_module(mod_path), cls_name)
    return cls.model_validate(value)

View File

@@ -21,6 +21,7 @@ if TYPE_CHECKING:
class BaseAgentExecutor(BaseModel):
model_config = {"arbitrary_types_allowed": True}
executor_type: str = "base"
crew: Crew | None = Field(default=None, exclude=True)
agent: BaseAgent | None = Field(default=None, exclude=True)
task: Task | None = Field(default=None, exclude=True)

View File

@@ -96,6 +96,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
LLM interactions, tool execution, and feedback handling.
"""
executor_type: Literal["crew"] = "crew"
llm: Annotated[
BaseLLM | str | None,
BeforeValidator(_validate_llm_ref),

View File

@@ -170,6 +170,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
_skip_auto_memory: bool = True
executor_type: Literal["experimental"] = "experimental"
suppress_flow_events: bool = True # always suppress for executor
llm: BaseLLM = Field(exclude=True)
prompt: SystemPromptResult | StandardPromptResult = Field(exclude=True)

View File

@@ -343,6 +343,7 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
llm_type: Literal["litellm"] = "litellm"
completion_cost: float | None = None
timeout: float | int | None = None
top_p: float | None = None

View File

@@ -117,6 +117,7 @@ class BaseLLM(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
llm_type: str = "base"
model: str
temperature: float | None = None
api_key: str | None = None

View File

@@ -148,6 +148,7 @@ class AnthropicCompletion(BaseLLM):
offering native tool use, streaming support, and proper message formatting.
"""
llm_type: Literal["anthropic"] = "anthropic"
model: str = "claude-3-5-sonnet-20241022"
timeout: float | None = None
max_retries: int = 2

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json
import logging
import os
from typing import Any, TypedDict
from typing import Any, Literal, TypedDict
from urllib.parse import urlparse
from pydantic import BaseModel, PrivateAttr, model_validator
@@ -74,6 +74,7 @@ class AzureCompletion(BaseLLM):
offering native function calling, streaming support, and proper Azure authentication.
"""
llm_type: Literal["azure"] = "azure"
endpoint: str | None = None
api_version: str | None = None
timeout: float | None = None

View File

@@ -5,7 +5,7 @@ from contextlib import AsyncExitStack
import json
import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict, cast
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
from pydantic import BaseModel, PrivateAttr, model_validator
from typing_extensions import Required
@@ -228,6 +228,7 @@ class BedrockCompletion(BaseLLM):
- Model-specific conversation format handling (e.g., Cohere requirements)
"""
llm_type: Literal["bedrock"] = "bedrock"
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0"
aws_access_key_id: str | None = None
aws_secret_access_key: str | None = None

View File

@@ -41,6 +41,7 @@ class GeminiCompletion(BaseLLM):
offering native function calling, streaming support, and proper Gemini formatting.
"""
llm_type: Literal["gemini"] = "gemini"
model: str = "gemini-2.0-flash-001"
project: str | None = None
location: str | None = None

View File

@@ -180,6 +180,8 @@ class OpenAICompletion(BaseLLM):
chain-of-thought without storing data on OpenAI servers.
"""
llm_type: Literal["openai"] = "openai"
BUILTIN_TOOL_TYPES: ClassVar[dict[str, str]] = {
"web_search": "web_search_preview",
"file_search": "file_search",