feat: convert executor/tools/prompts to BaseModel, enable checkpoint resume via kickoff()

This commit is contained in:
Greyson LaLonde
2026-04-03 12:03:11 +08:00
parent cf241d85e8
commit 2e1f882234
4 changed files with 79 additions and 80 deletions

View File

@@ -355,20 +355,19 @@ class Crew(FlowTrackable, BaseModel):
@classmethod
def from_checkpoint(cls, path: str) -> Crew:
"""Restore a Crew from a checkpoint file.
"""Restore a Crew from a checkpoint file, ready to resume via kickoff().
Args:
path: Path to a checkpoint JSON file.
Returns:
A Crew instance with state restored from the checkpoint.
A Crew instance. Call kickoff() to resume from the last completed task.
"""
from pathlib import Path as _Path
from crewai.context import apply_execution_context
json_str = _Path(path).read_text()
# Parse as RuntimeState to handle discriminated union
from crewai import RuntimeState
state = RuntimeState.model_validate_json(
@@ -378,9 +377,28 @@ class Crew(FlowTrackable, BaseModel):
if isinstance(entity, cls):
if entity.execution_context is not None:
apply_execution_context(entity.execution_context)
entity._restore_runtime()
return entity
raise ValueError(f"No Crew found in checkpoint: {path}")
def _restore_runtime(self) -> None:
    """Rebuild runtime-only links on agents and tasks after a checkpoint load.

    For every agent: an ``llm`` held as a string is passed through
    ``create_llm``, the agent's ``crew`` is pointed back at this crew, and
    its ``agent_executor`` is cleared. Every task that has an agent is then
    re-bound to the first crew agent with an equal ``role``; when no role
    matches, the task keeps the agent it already has. Finally, the
    checkpointed inputs and kickoff event id, when present, are copied into
    ``_inputs`` and ``_kickoff_event_id``.
    """
    for member in self.agents:
        if isinstance(member.llm, str):
            member.llm = create_llm(member.llm)
        member.crew = self
        member.agent_executor = None

    for task in self.tasks:
        if task.agent is None:
            continue
        # First agent in crew order whose role equals the task agent's role.
        same_role = next(
            (member for member in self.agents if member.role == task.agent.role),
            None,
        )
        if same_role is not None:
            task.agent = same_role

    saved_inputs = self.checkpoint_inputs
    if saved_inputs is not None:
        self._inputs = saved_inputs

    saved_event_id = self.checkpoint_kickoff_event_id
    if saved_event_id is not None:
        self._kickoff_event_id = saved_event_id
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None:
@@ -1264,6 +1282,9 @@ class Crew(FlowTrackable, BaseModel):
manager.crew = self
def _get_execution_start_index(self, tasks: list[Task]) -> int | None:
    """Return the position of the first task without output, if past the start.

    Args:
        tasks: Tasks to inspect, in execution order.

    Returns:
        The index of the first task whose ``output`` is ``None`` when that
        index is greater than zero; ``None`` when the first such task is at
        index 0 or when every task already has an output.
    """
    first_pending = next(
        (position for position, task in enumerate(tasks) if task.output is None),
        None,
    )
    # Both "nothing pending" (None) and "pending at index 0" collapse to None.
    return first_pending or None
def _execute_tasks(

View File

@@ -7,14 +7,22 @@ import json
import textwrap
from typing import TYPE_CHECKING, Any, get_type_hints
from pydantic import BaseModel, Field, create_model
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
create_model,
model_validator,
)
from typing_extensions import Self
from crewai.utilities.logger import Logger
from crewai.utilities.string_utils import sanitize_tool_name
if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool
pass
def build_schema_hint(args_schema: type[BaseModel]) -> str:
@@ -42,49 +50,31 @@ class ToolUsageLimitExceededError(Exception):
"""Exception raised when a tool has reached its maximum usage limit."""
class CrewStructuredTool:
class CrewStructuredTool(BaseModel):
"""A structured tool that can operate on any number of inputs.
This tool intends to replace StructuredTool with a custom implementation
that integrates better with CrewAI's ecosystem.
"""
def __init__(
self,
name: str,
description: str,
args_schema: type[BaseModel],
func: Callable[..., Any],
result_as_answer: bool = False,
max_usage_count: int | None = None,
current_usage_count: int = 0,
cache_function: Callable[..., bool] | None = None,
) -> None:
"""Initialize the structured tool.
model_config = ConfigDict(arbitrary_types_allowed=True)
Args:
name: The name of the tool
description: A description of what the tool does
args_schema: The pydantic model for the tool's arguments
func: The function to run when the tool is called
result_as_answer: Whether to return the output directly
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
current_usage_count: Current number of times this tool has been used.
cache_function: Function to determine if the tool result should be cached.
"""
self.name = name
self.description = description
self.args_schema = args_schema
self.func = func
self._logger = Logger()
self.result_as_answer = result_as_answer
self.max_usage_count = max_usage_count
self.current_usage_count = current_usage_count
self.cache_function = cache_function
self._original_tool: BaseTool | None = None
name: str = Field(default="")
description: str = Field(default="")
args_schema: Any = Field(default=None)
func: Any = Field(default=None, exclude=True)
result_as_answer: bool = Field(default=False)
max_usage_count: int | None = Field(default=None)
current_usage_count: int = Field(default=0)
cache_function: Any = Field(default=None, exclude=True)
_logger: Logger = PrivateAttr(default_factory=Logger)
_original_tool: Any = PrivateAttr(default=None)
# Validate the function signature matches the schema
self._validate_function_signature()
@model_validator(mode="after")
def _validate_func(self) -> Self:
    """Check the function signature once the model is built, if a function is set.

    Returns:
        The same instance. ``_validate_function_signature`` is only invoked
        when ``func`` is not ``None``.
    """
    if self.func is None:
        # No callable attached, so there is no signature to check.
        return self
    self._validate_function_signature()
    return self
@classmethod
def from_function(
@@ -230,7 +220,7 @@ class CrewStructuredTool:
try:
validated_args = self.args_schema.model_validate(raw_args)
return validated_args.model_dump()
return dict(validated_args.model_dump())
except Exception as e:
hint = build_schema_hint(self.args_schema)
raise ValueError(f"Arguments validation failed: {e}{hint}") from e

View File

@@ -2,25 +2,33 @@
from __future__ import annotations
from typing import Annotated, Any, Literal
from typing import Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from crewai.utilities.i18n import I18N, get_i18n
class StandardPromptResult(TypedDict):
class StandardPromptResult(BaseModel):
"""Result with only prompt field for standard mode."""
prompt: Annotated[str, "The generated prompt string"]
prompt: str = Field(default="")
def get(self, key: str, default: Any = None) -> Any:
    """Return the attribute named ``key``, or ``default`` when it is absent.

    Args:
        key: Attribute name to look up on this instance.
        default: Value returned when the lookup raises ``AttributeError``.
    """
    try:
        return getattr(self, key)
    except AttributeError:
        return default
def __getitem__(self, key: str) -> Any:
    """Return the attribute named ``key`` to support ``obj[key]`` access.

    Raises:
        AttributeError: If the instance has no attribute called ``key``.
    """
    value = getattr(self, key)
    return value
def __contains__(self, key: str) -> bool:
    """Report whether ``key`` names an attribute whose value is not ``None``.

    An attribute that is missing and one that is set to ``None`` both count
    as "not contained".
    """
    # A missing attribute falls back to None, so one lookup covers both cases.
    return getattr(self, key, None) is not None
class SystemPromptResult(StandardPromptResult):
"""Result with system, user, and prompt fields for system prompt mode."""
system: Annotated[str, "The system prompt component"]
user: Annotated[str, "The user prompt component"]
system: str = Field(default="")
user: str = Field(default="")
COMPONENTS = Literal[

View File

@@ -7,6 +7,8 @@ when available (for the litellm fallback path).
from typing import Any
from pydantic import BaseModel, Field
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.logger_utils import suppress_warnings
@@ -21,35 +23,24 @@ except ImportError:
LITELLM_AVAILABLE = False
# Create a base class that conditionally inherits from litellm's CustomLogger
# when available, or from object when not available
if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None:
_BaseClass: type = LiteLLMCustomLogger
else:
_BaseClass = object
class TokenCalcHandler(_BaseClass): # type: ignore[misc]
class TokenCalcHandler(BaseModel):
"""Handler for calculating and tracking token usage in LLM calls.
This handler tracks prompt tokens, completion tokens, and cached tokens
across requests. It works standalone and also integrates with litellm's
logging system when litellm is installed (for the fallback path).
Attributes:
token_cost_process: The token process tracker to accumulate usage metrics.
"""
def __init__(self, token_cost_process: TokenProcess | None, **kwargs: Any) -> None:
"""Initialize the token calculation handler.
model_config = {"arbitrary_types_allowed": True}
Args:
token_cost_process: Optional token process tracker for accumulating metrics.
"""
# Only call super().__init__ if we have a real parent class with __init__
if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None:
super().__init__(**kwargs)
self.token_cost_process = token_cost_process
token_cost_process: TokenProcess | None = Field(default=None)
def __init__(
    self, token_cost_process: TokenProcess | None = None, /, **kwargs: Any
) -> None:
    """Initialize the handler, accepting the token tracker positionally.

    Args:
        token_cost_process: Optional positional-only token tracker. When it
            is not ``None`` it is forwarded to the parent initializer under
            the ``token_cost_process`` keyword, replacing any value already
            supplied under that keyword.
        **kwargs: Keyword arguments forwarded to the parent initializer.
    """
    if token_cost_process is None:
        init_kwargs = kwargs
    else:
        init_kwargs = {**kwargs, "token_cost_process": token_cost_process}
    super().__init__(**init_kwargs)
def log_success_event(
self,
@@ -58,18 +49,7 @@ class TokenCalcHandler(_BaseClass): # type: ignore[misc]
start_time: float,
end_time: float,
) -> None:
"""Log successful LLM API call and track token usage.
This method has the same interface as litellm's CustomLogger.log_success_event()
so it can be used as a litellm callback when litellm is installed, or called
directly when litellm is not installed.
Args:
kwargs: The arguments passed to the LLM call.
response_obj: The response object from the LLM API.
start_time: The timestamp when the call started.
end_time: The timestamp when the call completed.
"""
"""Log successful LLM API call and track token usage."""
if self.token_cost_process is None:
return