diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 37b9b1de9..162462fe4 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -355,20 +355,19 @@ class Crew(FlowTrackable, BaseModel): @classmethod def from_checkpoint(cls, path: str) -> Crew: - """Restore a Crew from a checkpoint file. + """Restore a Crew from a checkpoint file, ready to resume via kickoff(). Args: path: Path to a checkpoint JSON file. Returns: - A Crew instance with state restored from the checkpoint. + A Crew instance. Call kickoff() to resume from the last completed task. """ from pathlib import Path as _Path from crewai.context import apply_execution_context json_str = _Path(path).read_text() - # Parse as RuntimeState to handle discriminated union from crewai import RuntimeState state = RuntimeState.model_validate_json( @@ -378,9 +377,28 @@ class Crew(FlowTrackable, BaseModel): if isinstance(entity, cls): if entity.execution_context is not None: apply_execution_context(entity.execution_context) + entity._restore_runtime() return entity raise ValueError(f"No Crew found in checkpoint: {path}") + def _restore_runtime(self) -> None: + """Re-create runtime objects after restoring from a checkpoint.""" + for agent in self.agents: + if isinstance(agent.llm, str): + agent.llm = create_llm(agent.llm) + agent.crew = self + agent.agent_executor = None + for task in self.tasks: + if task.agent is not None: + for agent in self.agents: + if agent.role == task.agent.role: + task.agent = agent + break + if self.checkpoint_inputs is not None: + self._inputs = self.checkpoint_inputs + if self.checkpoint_kickoff_event_id is not None: + self._kickoff_event_id = self.checkpoint_kickoff_event_id + @field_validator("id", mode="before") @classmethod def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None: @@ -1264,6 +1282,9 @@ class Crew(FlowTrackable, BaseModel): manager.crew = self def _get_execution_start_index(self, tasks: list[Task]) -> int | None: + for i, task in enumerate(tasks): + if task.output is None: + return i if i > 0 else None return None def _execute_tasks( diff --git a/lib/crewai/src/crewai/tools/structured_tool.py b/lib/crewai/src/crewai/tools/structured_tool.py index 60a457f3b..766c75f20 100644 --- a/lib/crewai/src/crewai/tools/structured_tool.py +++ b/lib/crewai/src/crewai/tools/structured_tool.py @@ -7,14 +7,22 @@ import json import textwrap from typing import TYPE_CHECKING, Any, get_type_hints -from pydantic import BaseModel, Field, create_model +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + create_model, + model_validator, +) +from typing_extensions import Self from crewai.utilities.logger import Logger from crewai.utilities.string_utils import sanitize_tool_name if TYPE_CHECKING: - from crewai.tools.base_tool import BaseTool + pass def build_schema_hint(args_schema: type[BaseModel]) -> str: @@ -42,49 +50,31 @@ class ToolUsageLimitExceededError(Exception): """Exception raised when a tool has reached its maximum usage limit.""" -class CrewStructuredTool: +class CrewStructuredTool(BaseModel): """A structured tool that can operate on any number of inputs. This tool intends to replace StructuredTool with a custom implementation that integrates better with CrewAI's ecosystem. """ - def __init__( - self, - name: str, - description: str, - args_schema: type[BaseModel], - func: Callable[..., Any], - result_as_answer: bool = False, - max_usage_count: int | None = None, - current_usage_count: int = 0, - cache_function: Callable[..., bool] | None = None, - ) -> None: - """Initialize the structured tool. + model_config = ConfigDict(arbitrary_types_allowed=True) - Args: - name: The name of the tool - description: A description of what the tool does - args_schema: The pydantic model for the tool's arguments - func: The function to run when the tool is called - result_as_answer: Whether to return the output directly - max_usage_count: Maximum number of times this tool can be used. None means unlimited usage. - current_usage_count: Current number of times this tool has been used. - cache_function: Function to determine if the tool result should be cached. - """ - self.name = name - self.description = description - self.args_schema = args_schema - self.func = func - self._logger = Logger() - self.result_as_answer = result_as_answer - self.max_usage_count = max_usage_count - self.current_usage_count = current_usage_count - self.cache_function = cache_function - self._original_tool: BaseTool | None = None + name: str = Field(default="") + description: str = Field(default="") + args_schema: Any = Field(default=None) + func: Any = Field(default=None, exclude=True) + result_as_answer: bool = Field(default=False) + max_usage_count: int | None = Field(default=None) + current_usage_count: int = Field(default=0) + cache_function: Any = Field(default=None, exclude=True) + _logger: Logger = PrivateAttr(default_factory=Logger) + _original_tool: Any = PrivateAttr(default=None) - # Validate the function signature matches the schema - self._validate_function_signature() + @model_validator(mode="after") + def _validate_func(self) -> Self: + if self.func is not None: + self._validate_function_signature() + return self @classmethod def from_function( @@ -230,7 +220,7 @@ class CrewStructuredTool: try: validated_args = self.args_schema.model_validate(raw_args) - return validated_args.model_dump() + return dict(validated_args.model_dump()) except Exception as e: hint = build_schema_hint(self.args_schema) raise ValueError(f"Arguments validation failed: {e}{hint}") from e diff --git a/lib/crewai/src/crewai/utilities/prompts.py b/lib/crewai/src/crewai/utilities/prompts.py index e88a9708a..821623b89 100644 --- a/lib/crewai/src/crewai/utilities/prompts.py +++ b/lib/crewai/src/crewai/utilities/prompts.py @@ -2,25 +2,33 @@ from __future__ import annotations -from typing import Annotated, Any, Literal +from typing import Any, Literal from pydantic import BaseModel, Field -from typing_extensions import TypedDict from crewai.utilities.i18n import I18N, get_i18n -class StandardPromptResult(TypedDict): +class StandardPromptResult(BaseModel): """Result with only prompt field for standard mode.""" - prompt: Annotated[str, "The generated prompt string"] + prompt: str = Field(default="") + + def get(self, key: str, default: Any = None) -> Any: + return getattr(self, key, default) + + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + def __contains__(self, key: str) -> bool: + return hasattr(self, key) and getattr(self, key) is not None class SystemPromptResult(StandardPromptResult): """Result with system, user, and prompt fields for system prompt mode.""" - system: Annotated[str, "The system prompt component"] - user: Annotated[str, "The user prompt component"] + system: str = Field(default="") + user: str = Field(default="") COMPONENTS = Literal[ diff --git a/lib/crewai/src/crewai/utilities/token_counter_callback.py b/lib/crewai/src/crewai/utilities/token_counter_callback.py index 9c3a5cc5f..64a0ab299 100644 --- a/lib/crewai/src/crewai/utilities/token_counter_callback.py +++ b/lib/crewai/src/crewai/utilities/token_counter_callback.py @@ -7,6 +7,8 @@ when available (for the litellm fallback path). from typing import Any +from pydantic import BaseModel, Field + from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.utilities.logger_utils import suppress_warnings @@ -21,35 +23,24 @@ except ImportError: LITELLM_AVAILABLE = False -# Create a base class that conditionally inherits from litellm's CustomLogger -# when available, or from object when not available -if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None: - _BaseClass: type = LiteLLMCustomLogger -else: - _BaseClass = object - - -class TokenCalcHandler(_BaseClass): # type: ignore[misc] +class TokenCalcHandler(BaseModel): """Handler for calculating and tracking token usage in LLM calls. This handler tracks prompt tokens, completion tokens, and cached tokens across requests. It works standalone and also integrates with litellm's logging system when litellm is installed (for the fallback path). - - Attributes: - token_cost_process: The token process tracker to accumulate usage metrics. """ - def __init__(self, token_cost_process: TokenProcess | None, **kwargs: Any) -> None: - """Initialize the token calculation handler. + model_config = {"arbitrary_types_allowed": True} - Args: - token_cost_process: Optional token process tracker for accumulating metrics. - """ - # Only call super().__init__ if we have a real parent class with __init__ - if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None: - super().__init__(**kwargs) - self.token_cost_process = token_cost_process + token_cost_process: TokenProcess | None = Field(default=None) + + def __init__( + self, token_cost_process: TokenProcess | None = None, /, **kwargs: Any + ) -> None: + if token_cost_process is not None: + kwargs["token_cost_process"] = token_cost_process + super().__init__(**kwargs) def log_success_event( self, @@ -58,18 +49,7 @@ class TokenCalcHandler(_BaseClass): # type: ignore[misc] start_time: float, end_time: float, ) -> None: - """Log successful LLM API call and track token usage. - - This method has the same interface as litellm's CustomLogger.log_success_event() - so it can be used as a litellm callback when litellm is installed, or called - directly when litellm is not installed. - - Args: - kwargs: The arguments passed to the LLM call. - response_obj: The response object from the LLM API. - start_time: The timestamp when the call started. - end_time: The timestamp when the call completed. - """ + """Log successful LLM API call and track token usage.""" if self.token_cost_process is None: return