mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-04 06:29:22 +00:00
feat: add llm_type and executor_type discriminators for checkpoint fidelity
Add type discriminator fields to BaseLLM subclasses and BaseAgentExecutor subclasses so checkpoint deserialization restores the correct provider class instead of always creating LLM/CrewAgentExecutor.
This commit is contained in:
@@ -64,11 +64,26 @@ def _serialize_crew_ref(value: Any) -> str | None:
|
||||
return str(value.id) if hasattr(value, "id") else str(value)
|
||||
|
||||
|
||||
_LLM_TYPE_REGISTRY: dict[str, str] = {
|
||||
"base": "crewai.llms.base_llm.BaseLLM",
|
||||
"litellm": "crewai.llm.LLM",
|
||||
"openai": "crewai.llms.providers.openai.completion.OpenAICompletion",
|
||||
"anthropic": "crewai.llms.providers.anthropic.completion.AnthropicCompletion",
|
||||
"azure": "crewai.llms.providers.azure.completion.AzureCompletion",
|
||||
"bedrock": "crewai.llms.providers.bedrock.completion.BedrockCompletion",
|
||||
"gemini": "crewai.llms.providers.gemini.completion.GeminiCompletion",
|
||||
}
|
||||
|
||||
|
||||
def _validate_llm_ref(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
from crewai.llm import LLM
|
||||
import importlib
|
||||
|
||||
return LLM(**value)
|
||||
llm_type = value["llm_type"]
|
||||
dotted = _LLM_TYPE_REGISTRY[llm_type]
|
||||
mod_path, cls_name = dotted.rsplit(".", 1)
|
||||
cls = getattr(importlib.import_module(mod_path), cls_name)
|
||||
return cls(**value)
|
||||
return value
|
||||
|
||||
|
||||
@@ -80,11 +95,22 @@ def _resolve_agent(value: Any, info: Any) -> Any:
|
||||
return Agent.model_validate(value, context=getattr(info, "context", None))
|
||||
|
||||
|
||||
_EXECUTOR_TYPE_REGISTRY: dict[str, str] = {
|
||||
"base": "crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor",
|
||||
"crew": "crewai.agents.crew_agent_executor.CrewAgentExecutor",
|
||||
"experimental": "crewai.experimental.agent_executor.AgentExecutor",
|
||||
}
|
||||
|
||||
|
||||
def _validate_executor_ref(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
import importlib
|
||||
|
||||
return CrewAgentExecutor.model_validate(value)
|
||||
executor_type = value["executor_type"]
|
||||
dotted = _EXECUTOR_TYPE_REGISTRY[executor_type]
|
||||
mod_path, cls_name = dotted.rsplit(".", 1)
|
||||
cls = getattr(importlib.import_module(mod_path), cls_name)
|
||||
return cls.model_validate(value)
|
||||
return value
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
|
||||
class BaseAgentExecutor(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
executor_type: str = "base"
|
||||
crew: Crew | None = Field(default=None, exclude=True)
|
||||
agent: BaseAgent | None = Field(default=None, exclude=True)
|
||||
task: Task | None = Field(default=None, exclude=True)
|
||||
|
||||
@@ -96,6 +96,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
LLM interactions, tool execution, and feedback handling.
|
||||
"""
|
||||
|
||||
executor_type: Literal["crew"] = "crew"
|
||||
llm: Annotated[
|
||||
BaseLLM | str | None,
|
||||
BeforeValidator(_validate_llm_ref),
|
||||
|
||||
@@ -170,6 +170,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
|
||||
_skip_auto_memory: bool = True
|
||||
|
||||
executor_type: Literal["experimental"] = "experimental"
|
||||
suppress_flow_events: bool = True # always suppress for executor
|
||||
llm: BaseLLM = Field(exclude=True)
|
||||
prompt: SystemPromptResult | StandardPromptResult = Field(exclude=True)
|
||||
|
||||
@@ -343,6 +343,7 @@ class AccumulatedToolArgs(BaseModel):
|
||||
|
||||
|
||||
class LLM(BaseLLM):
|
||||
llm_type: Literal["litellm"] = "litellm"
|
||||
completion_cost: float | None = None
|
||||
timeout: float | int | None = None
|
||||
top_p: float | None = None
|
||||
|
||||
@@ -117,6 +117,7 @@ class BaseLLM(BaseModel, ABC):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
llm_type: str = "base"
|
||||
model: str
|
||||
temperature: float | None = None
|
||||
api_key: str | None = None
|
||||
|
||||
@@ -148,6 +148,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
offering native tool use, streaming support, and proper message formatting.
|
||||
"""
|
||||
|
||||
llm_type: Literal["anthropic"] = "anthropic"
|
||||
model: str = "claude-3-5-sonnet-20241022"
|
||||
timeout: float | None = None
|
||||
max_retries: int = 2
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, Literal, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
@@ -74,6 +74,7 @@ class AzureCompletion(BaseLLM):
|
||||
offering native function calling, streaming support, and proper Azure authentication.
|
||||
"""
|
||||
|
||||
llm_type: Literal["azure"] = "azure"
|
||||
endpoint: str | None = None
|
||||
api_version: str | None = None
|
||||
timeout: float | None = None
|
||||
|
||||
@@ -5,7 +5,7 @@ from contextlib import AsyncExitStack
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
from typing_extensions import Required
|
||||
@@ -228,6 +228,7 @@ class BedrockCompletion(BaseLLM):
|
||||
- Model-specific conversation format handling (e.g., Cohere requirements)
|
||||
"""
|
||||
|
||||
llm_type: Literal["bedrock"] = "bedrock"
|
||||
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
aws_access_key_id: str | None = None
|
||||
aws_secret_access_key: str | None = None
|
||||
|
||||
@@ -41,6 +41,7 @@ class GeminiCompletion(BaseLLM):
|
||||
offering native function calling, streaming support, and proper Gemini formatting.
|
||||
"""
|
||||
|
||||
llm_type: Literal["gemini"] = "gemini"
|
||||
model: str = "gemini-2.0-flash-001"
|
||||
project: str | None = None
|
||||
location: str | None = None
|
||||
|
||||
@@ -180,6 +180,8 @@ class OpenAICompletion(BaseLLM):
|
||||
chain-of-thought without storing data on OpenAI servers.
|
||||
"""
|
||||
|
||||
llm_type: Literal["openai"] = "openai"
|
||||
|
||||
BUILTIN_TOOL_TYPES: ClassVar[dict[str, str]] = {
|
||||
"web_search": "web_search_preview",
|
||||
"file_search": "file_search",
|
||||
|
||||
Reference in New Issue
Block a user