mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 13:48:09 +00:00
feat: convert executor/tools/prompts to BaseModel, enable checkpoint resume via kickoff()
This commit is contained in:
@@ -355,20 +355,19 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(cls, path: str) -> Crew:
|
||||
"""Restore a Crew from a checkpoint file.
|
||||
"""Restore a Crew from a checkpoint file, ready to resume via kickoff().
|
||||
|
||||
Args:
|
||||
path: Path to a checkpoint JSON file.
|
||||
|
||||
Returns:
|
||||
A Crew instance with state restored from the checkpoint.
|
||||
A Crew instance. Call kickoff() to resume from the last completed task.
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
from crewai.context import apply_execution_context
|
||||
|
||||
json_str = _Path(path).read_text()
|
||||
# Parse as RuntimeState to handle discriminated union
|
||||
from crewai import RuntimeState
|
||||
|
||||
state = RuntimeState.model_validate_json(
|
||||
@@ -378,9 +377,28 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if isinstance(entity, cls):
|
||||
if entity.execution_context is not None:
|
||||
apply_execution_context(entity.execution_context)
|
||||
entity._restore_runtime()
|
||||
return entity
|
||||
raise ValueError(f"No Crew found in checkpoint: {path}")
|
||||
|
||||
def _restore_runtime(self) -> None:
|
||||
"""Re-create runtime objects after restoring from a checkpoint."""
|
||||
for agent in self.agents:
|
||||
if isinstance(agent.llm, str):
|
||||
agent.llm = create_llm(agent.llm)
|
||||
agent.crew = self
|
||||
agent.agent_executor = None
|
||||
for task in self.tasks:
|
||||
if task.agent is not None:
|
||||
for agent in self.agents:
|
||||
if agent.role == task.agent.role:
|
||||
task.agent = agent
|
||||
break
|
||||
if self.checkpoint_inputs is not None:
|
||||
self._inputs = self.checkpoint_inputs
|
||||
if self.checkpoint_kickoff_event_id is not None:
|
||||
self._kickoff_event_id = self.checkpoint_kickoff_event_id
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None:
|
||||
@@ -1264,6 +1282,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
manager.crew = self
|
||||
|
||||
def _get_execution_start_index(self, tasks: list[Task]) -> int | None:
|
||||
for i, task in enumerate(tasks):
|
||||
if task.output is None:
|
||||
return i if i > 0 else None
|
||||
return None
|
||||
|
||||
def _execute_tasks(
|
||||
|
||||
@@ -7,14 +7,22 @@ import json
|
||||
import textwrap
|
||||
from typing import TYPE_CHECKING, Any, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
create_model,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
pass
|
||||
|
||||
|
||||
def build_schema_hint(args_schema: type[BaseModel]) -> str:
|
||||
@@ -42,49 +50,31 @@ class ToolUsageLimitExceededError(Exception):
|
||||
"""Exception raised when a tool has reached its maximum usage limit."""
|
||||
|
||||
|
||||
class CrewStructuredTool:
|
||||
class CrewStructuredTool(BaseModel):
|
||||
"""A structured tool that can operate on any number of inputs.
|
||||
|
||||
This tool intends to replace StructuredTool with a custom implementation
|
||||
that integrates better with CrewAI's ecosystem.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
args_schema: type[BaseModel],
|
||||
func: Callable[..., Any],
|
||||
result_as_answer: bool = False,
|
||||
max_usage_count: int | None = None,
|
||||
current_usage_count: int = 0,
|
||||
cache_function: Callable[..., bool] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the structured tool.
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
Args:
|
||||
name: The name of the tool
|
||||
description: A description of what the tool does
|
||||
args_schema: The pydantic model for the tool's arguments
|
||||
func: The function to run when the tool is called
|
||||
result_as_answer: Whether to return the output directly
|
||||
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
|
||||
current_usage_count: Current number of times this tool has been used.
|
||||
cache_function: Function to determine if the tool result should be cached.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.args_schema = args_schema
|
||||
self.func = func
|
||||
self._logger = Logger()
|
||||
self.result_as_answer = result_as_answer
|
||||
self.max_usage_count = max_usage_count
|
||||
self.current_usage_count = current_usage_count
|
||||
self.cache_function = cache_function
|
||||
self._original_tool: BaseTool | None = None
|
||||
name: str = Field(default="")
|
||||
description: str = Field(default="")
|
||||
args_schema: Any = Field(default=None)
|
||||
func: Any = Field(default=None, exclude=True)
|
||||
result_as_answer: bool = Field(default=False)
|
||||
max_usage_count: int | None = Field(default=None)
|
||||
current_usage_count: int = Field(default=0)
|
||||
cache_function: Any = Field(default=None, exclude=True)
|
||||
_logger: Logger = PrivateAttr(default_factory=Logger)
|
||||
_original_tool: Any = PrivateAttr(default=None)
|
||||
|
||||
# Validate the function signature matches the schema
|
||||
self._validate_function_signature()
|
||||
@model_validator(mode="after")
|
||||
def _validate_func(self) -> Self:
|
||||
if self.func is not None:
|
||||
self._validate_function_signature()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
@@ -230,7 +220,7 @@ class CrewStructuredTool:
|
||||
|
||||
try:
|
||||
validated_args = self.args_schema.model_validate(raw_args)
|
||||
return validated_args.model_dump()
|
||||
return dict(validated_args.model_dump())
|
||||
except Exception as e:
|
||||
hint = build_schema_hint(self.args_schema)
|
||||
raise ValueError(f"Arguments validation failed: {e}{hint}") from e
|
||||
|
||||
@@ -2,25 +2,33 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
|
||||
|
||||
class StandardPromptResult(TypedDict):
|
||||
class StandardPromptResult(BaseModel):
|
||||
"""Result with only prompt field for standard mode."""
|
||||
|
||||
prompt: Annotated[str, "The generated prompt string"]
|
||||
prompt: str = Field(default="")
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return getattr(self, key, default)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return getattr(self, key)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return hasattr(self, key) and getattr(self, key) is not None
|
||||
|
||||
|
||||
class SystemPromptResult(StandardPromptResult):
|
||||
"""Result with system, user, and prompt fields for system prompt mode."""
|
||||
|
||||
system: Annotated[str, "The system prompt component"]
|
||||
user: Annotated[str, "The user prompt component"]
|
||||
system: str = Field(default="")
|
||||
user: str = Field(default="")
|
||||
|
||||
|
||||
COMPONENTS = Literal[
|
||||
|
||||
@@ -7,6 +7,8 @@ when available (for the litellm fallback path).
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.utilities.logger_utils import suppress_warnings
|
||||
|
||||
@@ -21,35 +23,24 @@ except ImportError:
|
||||
LITELLM_AVAILABLE = False
|
||||
|
||||
|
||||
# Create a base class that conditionally inherits from litellm's CustomLogger
|
||||
# when available, or from object when not available
|
||||
if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None:
|
||||
_BaseClass: type = LiteLLMCustomLogger
|
||||
else:
|
||||
_BaseClass = object
|
||||
|
||||
|
||||
class TokenCalcHandler(_BaseClass): # type: ignore[misc]
|
||||
class TokenCalcHandler(BaseModel):
|
||||
"""Handler for calculating and tracking token usage in LLM calls.
|
||||
|
||||
This handler tracks prompt tokens, completion tokens, and cached tokens
|
||||
across requests. It works standalone and also integrates with litellm's
|
||||
logging system when litellm is installed (for the fallback path).
|
||||
|
||||
Attributes:
|
||||
token_cost_process: The token process tracker to accumulate usage metrics.
|
||||
"""
|
||||
|
||||
def __init__(self, token_cost_process: TokenProcess | None, **kwargs: Any) -> None:
|
||||
"""Initialize the token calculation handler.
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
Args:
|
||||
token_cost_process: Optional token process tracker for accumulating metrics.
|
||||
"""
|
||||
# Only call super().__init__ if we have a real parent class with __init__
|
||||
if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None:
|
||||
super().__init__(**kwargs)
|
||||
self.token_cost_process = token_cost_process
|
||||
token_cost_process: TokenProcess | None = Field(default=None)
|
||||
|
||||
def __init__(
|
||||
self, token_cost_process: TokenProcess | None = None, /, **kwargs: Any
|
||||
) -> None:
|
||||
if token_cost_process is not None:
|
||||
kwargs["token_cost_process"] = token_cost_process
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def log_success_event(
|
||||
self,
|
||||
@@ -58,18 +49,7 @@ class TokenCalcHandler(_BaseClass): # type: ignore[misc]
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
) -> None:
|
||||
"""Log successful LLM API call and track token usage.
|
||||
|
||||
This method has the same interface as litellm's CustomLogger.log_success_event()
|
||||
so it can be used as a litellm callback when litellm is installed, or called
|
||||
directly when litellm is not installed.
|
||||
|
||||
Args:
|
||||
kwargs: The arguments passed to the LLM call.
|
||||
response_obj: The response object from the LLM API.
|
||||
start_time: The timestamp when the call started.
|
||||
end_time: The timestamp when the call completed.
|
||||
"""
|
||||
"""Log successful LLM API call and track token usage."""
|
||||
if self.token_cost_process is None:
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user