mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-04 22:49:23 +00:00
Merge branch 'main' of github.com:crewAIInc/crewAI into lorenze/imp/memory-prompt-influence
This commit is contained in:
@@ -9,7 +9,7 @@ authors = [
|
||||
requires-python = ">=3.10, <3.14"
|
||||
dependencies = [
|
||||
"Pillow~=12.1.1",
|
||||
"pypdf~=6.9.1",
|
||||
"pypdf~=6.10.0",
|
||||
"python-magic>=0.4.27",
|
||||
"aiocache~=0.12.3",
|
||||
"aiofiles~=24.1.0",
|
||||
|
||||
@@ -152,4 +152,4 @@ __all__ = [
|
||||
"wrap_file_source",
|
||||
]
|
||||
|
||||
__version__ = "1.14.0"
|
||||
__version__ = "1.14.2rc1"
|
||||
|
||||
@@ -9,8 +9,8 @@ authors = [
|
||||
requires-python = ">=3.10, <3.14"
|
||||
dependencies = [
|
||||
"pytube~=15.0.0",
|
||||
"requests~=2.32.5",
|
||||
"crewai==1.14.0",
|
||||
"requests>=2.33.0,<3",
|
||||
"crewai==1.14.2rc1",
|
||||
"tiktoken~=0.8.0",
|
||||
"beautifulsoup4~=4.13.4",
|
||||
"python-docx~=1.2.0",
|
||||
|
||||
@@ -305,4 +305,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.14.0"
|
||||
__version__ = "1.14.2rc1"
|
||||
|
||||
@@ -154,21 +154,19 @@ class ToolSpecExtractor:
|
||||
|
||||
return default_value
|
||||
|
||||
# Dynamically computed from BaseTool so that any future fields or
|
||||
# computed_fields added to BaseTool are automatically excluded from
|
||||
# the generated spec — no hardcoded denylist to maintain.
|
||||
# ``package_dependencies`` is not a BaseTool field but is extracted
|
||||
# into its own top-level key, so it's also excluded from init_params.
|
||||
_BASE_TOOL_FIELDS: set[str] = (
|
||||
set(BaseTool.model_fields)
|
||||
| set(BaseTool.model_computed_fields)
|
||||
| {"package_dependencies"}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_init_params(tool_class: type[BaseTool]) -> dict[str, Any]:
|
||||
ignored_init_params = [
|
||||
"name",
|
||||
"description",
|
||||
"env_vars",
|
||||
"args_schema",
|
||||
"description_updated",
|
||||
"cache_function",
|
||||
"result_as_answer",
|
||||
"max_usage_count",
|
||||
"current_usage_count",
|
||||
"package_dependencies",
|
||||
]
|
||||
|
||||
json_schema = tool_class.model_json_schema(
|
||||
schema_generator=SchemaGenerator, mode="serialization"
|
||||
)
|
||||
@@ -176,8 +174,14 @@ class ToolSpecExtractor:
|
||||
json_schema["properties"] = {
|
||||
key: value
|
||||
for key, value in json_schema["properties"].items()
|
||||
if key not in ignored_init_params
|
||||
if key not in ToolSpecExtractor._BASE_TOOL_FIELDS
|
||||
}
|
||||
if "required" in json_schema:
|
||||
json_schema["required"] = [
|
||||
key
|
||||
for key in json_schema["required"]
|
||||
if key not in ToolSpecExtractor._BASE_TOOL_FIELDS
|
||||
]
|
||||
return json_schema
|
||||
|
||||
def save_to_json(self, output_path: str) -> None:
|
||||
|
||||
@@ -1,7 +1,17 @@
|
||||
from collections.abc import Iterator
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
try:
|
||||
@@ -12,6 +22,186 @@ try:
|
||||
except ImportError:
|
||||
SQLALCHEMY_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Commands allowed in read-only mode
|
||||
# NOTE: WITH is intentionally excluded — writable CTEs start with WITH, so the
|
||||
# CTE body must be inspected separately (see _validate_statement).
|
||||
_READ_ONLY_COMMANDS = {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"}
|
||||
|
||||
# Commands that mutate state and are blocked by default
|
||||
_WRITE_COMMANDS = {
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"TRUNCATE",
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
"EXEC",
|
||||
"EXECUTE",
|
||||
"CALL",
|
||||
"MERGE",
|
||||
"REPLACE",
|
||||
"UPSERT",
|
||||
"LOAD",
|
||||
"COPY",
|
||||
"VACUUM",
|
||||
"ANALYZE",
|
||||
"ANALYSE",
|
||||
"REINDEX",
|
||||
"CLUSTER",
|
||||
"REFRESH",
|
||||
"COMMENT",
|
||||
"SET",
|
||||
"RESET",
|
||||
}
|
||||
|
||||
|
||||
# Subset of write commands that can realistically appear *inside* a CTE body.
|
||||
# Narrower than _WRITE_COMMANDS to avoid false positives on identifiers like
|
||||
# ``comment``, ``set``, or ``reset`` which are common column/table names.
|
||||
_CTE_WRITE_INDICATORS = {
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"TRUNCATE",
|
||||
"MERGE",
|
||||
}
|
||||
|
||||
|
||||
_AS_PAREN_RE = re.compile(r"\bAS\s*\(", re.IGNORECASE)
|
||||
|
||||
|
||||
def _iter_as_paren_matches(stmt: str) -> Iterator[re.Match[str]]:
|
||||
"""Yield regex matches for ``AS\\s*(`` outside of string literals."""
|
||||
# Build a set of character positions that are inside string literals.
|
||||
in_string: set[int] = set()
|
||||
i = 0
|
||||
while i < len(stmt):
|
||||
if stmt[i] == "'":
|
||||
start = i
|
||||
end = _skip_string_literal(stmt, i)
|
||||
in_string.update(range(start, end))
|
||||
i = end
|
||||
else:
|
||||
i += 1
|
||||
|
||||
for m in _AS_PAREN_RE.finditer(stmt):
|
||||
if m.start() not in in_string:
|
||||
yield m
|
||||
|
||||
|
||||
def _detect_writable_cte(stmt: str) -> str | None:
|
||||
"""Return the first write command inside a CTE body, or None.
|
||||
|
||||
Instead of tokenizing the whole statement (which falsely matches column
|
||||
names like ``comment``), this walks through parenthesized CTE bodies and
|
||||
checks only the *first keyword after* an opening ``AS (`` for a write
|
||||
command. Uses a regex to handle any whitespace (spaces, tabs, newlines)
|
||||
between ``AS`` and ``(``. Skips matches inside string literals.
|
||||
"""
|
||||
for m in _iter_as_paren_matches(stmt):
|
||||
body = stmt[m.end() :].lstrip()
|
||||
first_word = body.split()[0].upper().strip("()") if body.split() else ""
|
||||
if first_word in _CTE_WRITE_INDICATORS:
|
||||
return first_word
|
||||
return None
|
||||
|
||||
|
||||
def _skip_string_literal(stmt: str, pos: int) -> int:
|
||||
"""Skip past a string literal starting at pos (single-quoted).
|
||||
|
||||
Handles escaped quotes ('') inside the literal.
|
||||
Returns the index after the closing quote.
|
||||
"""
|
||||
quote_char = stmt[pos]
|
||||
i = pos + 1
|
||||
while i < len(stmt):
|
||||
if stmt[i] == quote_char:
|
||||
# Check for escaped quote ('')
|
||||
if i + 1 < len(stmt) and stmt[i + 1] == quote_char:
|
||||
i += 2
|
||||
continue
|
||||
return i + 1
|
||||
i += 1
|
||||
return i # Unterminated literal — return end
|
||||
|
||||
|
||||
def _find_matching_close_paren(stmt: str, start: int) -> int:
|
||||
"""Find the matching close paren, skipping string literals."""
|
||||
depth = 1
|
||||
i = start
|
||||
while i < len(stmt) and depth > 0:
|
||||
ch = stmt[i]
|
||||
if ch == "'":
|
||||
i = _skip_string_literal(stmt, i)
|
||||
continue
|
||||
if ch == "(":
|
||||
depth += 1
|
||||
elif ch == ")":
|
||||
depth -= 1
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def _extract_main_query_after_cte(stmt: str) -> str | None:
|
||||
"""Extract the main (outer) query that follows all CTE definitions.
|
||||
|
||||
For ``WITH cte AS (SELECT 1) DELETE FROM users``, returns ``DELETE FROM users``.
|
||||
Returns None if no main query is found after the last CTE body.
|
||||
Handles parentheses inside string literals (e.g., ``SELECT '(' FROM t``).
|
||||
"""
|
||||
last_cte_end = 0
|
||||
for m in _iter_as_paren_matches(stmt):
|
||||
last_cte_end = _find_matching_close_paren(stmt, m.end())
|
||||
|
||||
if last_cte_end > 0:
|
||||
remainder = stmt[last_cte_end:].strip().lstrip(",").strip()
|
||||
if remainder:
|
||||
return remainder
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_explain_command(stmt: str) -> str | None:
|
||||
"""Resolve the underlying command from an EXPLAIN [ANALYZE] [VERBOSE] statement.
|
||||
|
||||
Returns the real command (e.g., 'DELETE') if ANALYZE is present, else None.
|
||||
Handles both space-separated and parenthesized syntax.
|
||||
"""
|
||||
rest = stmt.strip()[len("EXPLAIN") :].strip()
|
||||
if not rest:
|
||||
return None
|
||||
|
||||
analyze_found = False
|
||||
explain_opts = {"ANALYZE", "ANALYSE", "VERBOSE"}
|
||||
|
||||
if rest.startswith("("):
|
||||
close = rest.find(")")
|
||||
if close != -1:
|
||||
options_str = rest[1:close].upper()
|
||||
analyze_found = any(
|
||||
opt.strip() in ("ANALYZE", "ANALYSE") for opt in options_str.split(",")
|
||||
)
|
||||
rest = rest[close + 1 :].strip()
|
||||
else:
|
||||
while rest:
|
||||
first_opt = rest.split()[0].upper().rstrip(";") if rest.split() else ""
|
||||
if first_opt in ("ANALYZE", "ANALYSE"):
|
||||
analyze_found = True
|
||||
if first_opt not in explain_opts:
|
||||
break
|
||||
rest = rest[len(first_opt) :].strip()
|
||||
|
||||
if analyze_found and rest:
|
||||
return rest.split()[0].upper().rstrip(";")
|
||||
return None
|
||||
|
||||
|
||||
class NL2SQLToolInput(BaseModel):
|
||||
sql_query: str = Field(
|
||||
@@ -21,20 +211,70 @@ class NL2SQLToolInput(BaseModel):
|
||||
|
||||
|
||||
class NL2SQLTool(BaseTool):
|
||||
"""Tool that converts natural language to SQL and executes it against a database.
|
||||
|
||||
By default the tool operates in **read-only mode**: only SELECT, SHOW,
|
||||
DESCRIBE, EXPLAIN, and read-only CTEs (WITH … SELECT) are permitted. Write
|
||||
operations (INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, TRUNCATE, …) are
|
||||
blocked unless ``allow_dml=True`` is set explicitly or the environment
|
||||
variable ``CREWAI_NL2SQL_ALLOW_DML=true`` is present.
|
||||
|
||||
Writable CTEs (``WITH d AS (DELETE …) SELECT …``) and
|
||||
``EXPLAIN ANALYZE <write-stmt>`` are treated as write operations and are
|
||||
blocked in read-only mode.
|
||||
|
||||
The ``_fetch_all_available_columns`` helper uses parameterised queries so
|
||||
that table names coming from the database catalogue cannot be used as an
|
||||
injection vector.
|
||||
"""
|
||||
|
||||
name: str = "NL2SQLTool"
|
||||
description: str = "Converts natural language to SQL queries and executes them."
|
||||
description: str = (
|
||||
"Converts natural language to SQL queries and executes them against a "
|
||||
"database. Read-only by default — only SELECT/SHOW/DESCRIBE/EXPLAIN "
|
||||
"queries (and read-only CTEs) are allowed unless configured with "
|
||||
"allow_dml=True."
|
||||
)
|
||||
db_uri: str = Field(
|
||||
title="Database URI",
|
||||
description="The URI of the database to connect to.",
|
||||
)
|
||||
allow_dml: bool = Field(
|
||||
default=False,
|
||||
title="Allow DML",
|
||||
description=(
|
||||
"When False (default) only read statements are permitted. "
|
||||
"Set to True to allow INSERT/UPDATE/DELETE/DROP and other "
|
||||
"write operations."
|
||||
),
|
||||
)
|
||||
tables: list[dict[str, Any]] = Field(default_factory=list)
|
||||
columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
|
||||
args_schema: type[BaseModel] = NL2SQLToolInput
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _apply_env_override(self) -> Self:
|
||||
"""Allow CREWAI_NL2SQL_ALLOW_DML=true to override allow_dml at runtime."""
|
||||
if os.environ.get("CREWAI_NL2SQL_ALLOW_DML", "").strip().lower() == "true":
|
||||
if not self.allow_dml:
|
||||
logger.warning(
|
||||
"NL2SQLTool: CREWAI_NL2SQL_ALLOW_DML env var is set — "
|
||||
"DML/DDL operations are enabled. Ensure this is intentional."
|
||||
)
|
||||
self.allow_dml = True
|
||||
return self
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if not SQLALCHEMY_AVAILABLE:
|
||||
raise ImportError(
|
||||
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
|
||||
"sqlalchemy is not installed. Please install it with "
|
||||
"`pip install crewai-tools[sqlalchemy]`"
|
||||
)
|
||||
|
||||
if self.allow_dml:
|
||||
logger.warning(
|
||||
"NL2SQLTool: allow_dml=True — write operations (INSERT/UPDATE/"
|
||||
"DELETE/DROP/…) are permitted. Use with caution."
|
||||
)
|
||||
|
||||
data: dict[str, list[dict[str, Any]] | str] = {}
|
||||
@@ -50,42 +290,216 @@ class NL2SQLTool(BaseTool):
|
||||
self.tables = tables
|
||||
self.columns = data
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_query(self, sql_query: str) -> None:
|
||||
"""Raise ValueError if *sql_query* is not permitted under the current config.
|
||||
|
||||
Splits the query on semicolons and validates each statement
|
||||
independently. When ``allow_dml=False`` (the default), multi-statement
|
||||
queries are rejected outright to prevent ``SELECT 1; DROP TABLE users``
|
||||
style bypasses. When ``allow_dml=True`` every statement is checked and
|
||||
a warning is emitted for write operations.
|
||||
"""
|
||||
statements = [s.strip() for s in sql_query.split(";") if s.strip()]
|
||||
|
||||
if not statements:
|
||||
raise ValueError("NL2SQLTool received an empty SQL query.")
|
||||
|
||||
if not self.allow_dml and len(statements) > 1:
|
||||
raise ValueError(
|
||||
"NL2SQLTool blocked a multi-statement query in read-only mode. "
|
||||
"Semicolons are not permitted when allow_dml=False."
|
||||
)
|
||||
|
||||
for stmt in statements:
|
||||
self._validate_statement(stmt)
|
||||
|
||||
def _validate_statement(self, stmt: str) -> None:
|
||||
"""Validate a single SQL statement (no semicolons)."""
|
||||
command = self._extract_command(stmt)
|
||||
|
||||
# EXPLAIN ANALYZE / EXPLAIN ANALYSE actually *executes* the underlying
|
||||
# query. Resolve the real command so write operations are caught.
|
||||
# Handles both space-separated ("EXPLAIN ANALYZE DELETE …") and
|
||||
# parenthesized ("EXPLAIN (ANALYZE) DELETE …", "EXPLAIN (ANALYZE, VERBOSE) DELETE …").
|
||||
# EXPLAIN ANALYZE actually executes the underlying query — resolve the
|
||||
# real command so write operations are caught.
|
||||
if command == "EXPLAIN":
|
||||
resolved = _resolve_explain_command(stmt)
|
||||
if resolved:
|
||||
command = resolved
|
||||
|
||||
# WITH starts a CTE. Read-only CTEs are fine; writable CTEs
|
||||
# (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode.
|
||||
if command == "WITH":
|
||||
# Check for write commands inside CTE bodies.
|
||||
write_found = _detect_writable_cte(stmt)
|
||||
if write_found:
|
||||
found = write_found
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool is configured in read-only mode and blocked a "
|
||||
f"writable CTE containing a '{found}' statement. To allow "
|
||||
f"write operations set allow_dml=True or "
|
||||
f"CREWAI_NL2SQL_ALLOW_DML=true."
|
||||
)
|
||||
logger.warning(
|
||||
"NL2SQLTool: executing writable CTE with '%s' because allow_dml=True.",
|
||||
found,
|
||||
)
|
||||
return
|
||||
|
||||
# Check the main query after the CTE definitions.
|
||||
main_query = _extract_main_query_after_cte(stmt)
|
||||
if main_query:
|
||||
main_cmd = main_query.split()[0].upper().rstrip(";")
|
||||
if main_cmd in _WRITE_COMMANDS:
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool is configured in read-only mode and blocked a "
|
||||
f"'{main_cmd}' statement after a CTE. To allow write "
|
||||
f"operations set allow_dml=True or "
|
||||
f"CREWAI_NL2SQL_ALLOW_DML=true."
|
||||
)
|
||||
logger.warning(
|
||||
"NL2SQLTool: executing '%s' after CTE because allow_dml=True.",
|
||||
main_cmd,
|
||||
)
|
||||
elif main_cmd not in _READ_ONLY_COMMANDS:
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool blocked an unrecognised SQL command '{main_cmd}' "
|
||||
f"after a CTE. Only {sorted(_READ_ONLY_COMMANDS)} are allowed "
|
||||
f"in read-only mode."
|
||||
)
|
||||
return
|
||||
|
||||
if command in _WRITE_COMMANDS:
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool is configured in read-only mode and blocked a "
|
||||
f"'{command}' statement. To allow write operations set "
|
||||
f"allow_dml=True or CREWAI_NL2SQL_ALLOW_DML=true."
|
||||
)
|
||||
logger.warning(
|
||||
"NL2SQLTool: executing write statement '%s' because allow_dml=True.",
|
||||
command,
|
||||
)
|
||||
elif command not in _READ_ONLY_COMMANDS:
|
||||
# Unknown command — block by default unless DML is explicitly enabled
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool blocked an unrecognised SQL command '{command}'. "
|
||||
f"Only {sorted(_READ_ONLY_COMMANDS)} are allowed in read-only "
|
||||
f"mode."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_command(sql_query: str) -> str:
|
||||
"""Return the uppercased first keyword of *sql_query*."""
|
||||
stripped = sql_query.strip().lstrip("(")
|
||||
first_token = stripped.split()[0] if stripped.split() else ""
|
||||
return first_token.upper().rstrip(";")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Schema introspection helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
|
||||
return self.execute_sql(
|
||||
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
|
||||
"SELECT table_name FROM information_schema.tables "
|
||||
"WHERE table_schema = 'public';"
|
||||
)
|
||||
|
||||
def _fetch_all_available_columns(
|
||||
self, table_name: str
|
||||
) -> list[dict[str, Any]] | str:
|
||||
"""Fetch columns for *table_name* using a parameterised query.
|
||||
|
||||
The table name is bound via SQLAlchemy's ``:param`` syntax to prevent
|
||||
SQL injection from catalogue values.
|
||||
"""
|
||||
return self.execute_sql(
|
||||
f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';" # noqa: S608
|
||||
"SELECT column_name, data_type FROM information_schema.columns "
|
||||
"WHERE table_name = :table_name",
|
||||
params={"table_name": table_name},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
|
||||
try:
|
||||
self._validate_query(sql_query)
|
||||
data = self.execute_sql(sql_query)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
data = (
|
||||
f"Based on these tables {self.tables} and columns {self.columns}, "
|
||||
"you can create SQL queries to retrieve data from the database."
|
||||
f"Get the original request {sql_query} and the error {exc} and create the correct SQL query."
|
||||
"you can create SQL queries to retrieve data from the database. "
|
||||
f"Get the original request {sql_query} and the error {exc} and "
|
||||
"create the correct SQL query."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def execute_sql(self, sql_query: str) -> list[dict[str, Any]] | str:
|
||||
def execute_sql(
|
||||
self,
|
||||
sql_query: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]] | str:
|
||||
"""Execute *sql_query* and return the results as a list of dicts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sql_query:
|
||||
The SQL statement to run.
|
||||
params:
|
||||
Optional mapping of bind parameters (e.g. ``{"table_name": "users"}``).
|
||||
"""
|
||||
if not SQLALCHEMY_AVAILABLE:
|
||||
raise ImportError(
|
||||
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
|
||||
"sqlalchemy is not installed. Please install it with "
|
||||
"`pip install crewai-tools[sqlalchemy]`"
|
||||
)
|
||||
|
||||
# Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a
|
||||
# commit when allow_dml=True, regardless of statement order.
|
||||
_stmts = [s.strip() for s in sql_query.split(";") if s.strip()]
|
||||
|
||||
def _is_write_stmt(s: str) -> bool:
|
||||
cmd = self._extract_command(s)
|
||||
if cmd in _WRITE_COMMANDS:
|
||||
return True
|
||||
if cmd == "EXPLAIN":
|
||||
# Resolve the underlying command for EXPLAIN ANALYZE
|
||||
resolved = _resolve_explain_command(s)
|
||||
if resolved and resolved in _WRITE_COMMANDS:
|
||||
return True
|
||||
if cmd == "WITH":
|
||||
if _detect_writable_cte(s):
|
||||
return True
|
||||
main_q = _extract_main_query_after_cte(s)
|
||||
if main_q:
|
||||
return main_q.split()[0].upper().rstrip(";") in _WRITE_COMMANDS
|
||||
return False
|
||||
|
||||
is_write = any(_is_write_stmt(s) for s in _stmts)
|
||||
|
||||
engine = create_engine(self.db_uri)
|
||||
Session = sessionmaker(bind=engine) # noqa: N806
|
||||
session = Session()
|
||||
try:
|
||||
result = session.execute(text(sql_query))
|
||||
session.commit()
|
||||
result = session.execute(text(sql_query), params or {})
|
||||
|
||||
# Only commit when the operation actually mutates state
|
||||
if self.allow_dml and is_write:
|
||||
session.commit()
|
||||
|
||||
if result.returns_rows: # type: ignore[attr-defined]
|
||||
columns = result.keys()
|
||||
|
||||
@@ -45,6 +45,26 @@ class MockTool(BaseTool):
|
||||
)
|
||||
|
||||
|
||||
# --- Intermediate base class (like RagTool, BraveSearchToolBase) ---
|
||||
class MockIntermediateBase(BaseTool):
|
||||
"""Simulates an intermediate tool base class (e.g. RagTool, BraveSearchToolBase)."""
|
||||
|
||||
name: str = "Intermediate Base"
|
||||
description: str = "An intermediate tool base"
|
||||
shared_config: str = Field("default_config", description="Config from intermediate base")
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return query
|
||||
|
||||
|
||||
class MockDerivedTool(MockIntermediateBase):
|
||||
"""A tool inheriting from an intermediate base, like CodeDocsSearchTool(RagTool)."""
|
||||
|
||||
name: str = "Derived Tool"
|
||||
description: str = "A tool that inherits from intermediate base"
|
||||
derived_param: str = Field("derived_default", description="Param specific to derived tool")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def extractor():
|
||||
ext = ToolSpecExtractor()
|
||||
@@ -169,6 +189,87 @@ def test_extract_package_dependencies(mock_tool_extractor):
|
||||
]
|
||||
|
||||
|
||||
def test_base_tool_fields_excluded_from_init_params(mock_tool_extractor):
|
||||
"""BaseTool internal fields (including computed_field like tool_type) must
|
||||
never appear in init_params_schema. Studio reads this schema to render
|
||||
the tool config UI — internal fields confuse users."""
|
||||
init_schema = mock_tool_extractor["init_params_schema"]
|
||||
props = set(init_schema.get("properties", {}).keys())
|
||||
required = set(init_schema.get("required", []))
|
||||
|
||||
# These are all BaseTool's own fields — none should leak
|
||||
base_fields = {"name", "description", "env_vars", "args_schema",
|
||||
"description_updated", "cache_function", "result_as_answer",
|
||||
"max_usage_count", "current_usage_count", "tool_type",
|
||||
"package_dependencies"}
|
||||
|
||||
leaked_props = base_fields & props
|
||||
assert not leaked_props, (
|
||||
f"BaseTool fields leaked into init_params_schema properties: {leaked_props}"
|
||||
)
|
||||
leaked_required = base_fields & required
|
||||
assert not leaked_required, (
|
||||
f"BaseTool fields leaked into init_params_schema required: {leaked_required}"
|
||||
)
|
||||
|
||||
|
||||
def test_intermediate_base_fields_preserved_for_derived_tool(extractor):
|
||||
"""When a tool inherits from an intermediate base (e.g. RagTool),
|
||||
the intermediate's fields should be included — only BaseTool's own
|
||||
fields are excluded."""
|
||||
with (
|
||||
mock.patch(
|
||||
"crewai_tools.generate_tool_specs.dir",
|
||||
return_value=["MockDerivedTool"],
|
||||
),
|
||||
mock.patch(
|
||||
"crewai_tools.generate_tool_specs.getattr",
|
||||
return_value=MockDerivedTool,
|
||||
),
|
||||
):
|
||||
extractor.extract_all_tools()
|
||||
assert len(extractor.tools_spec) == 1
|
||||
tool_info = extractor.tools_spec[0]
|
||||
|
||||
props = set(tool_info["init_params_schema"].get("properties", {}).keys())
|
||||
|
||||
# Intermediate base's field should be preserved
|
||||
assert "shared_config" in props, (
|
||||
"Intermediate base class fields should be preserved in init_params_schema"
|
||||
)
|
||||
# Derived tool's own field should be preserved
|
||||
assert "derived_param" in props, (
|
||||
"Derived tool's own fields should be preserved in init_params_schema"
|
||||
)
|
||||
# BaseTool internals should still be excluded
|
||||
assert "tool_type" not in props
|
||||
assert "cache_function" not in props
|
||||
assert "result_as_answer" not in props
|
||||
|
||||
|
||||
def test_future_base_tool_field_auto_excluded(extractor):
|
||||
"""If a new field is added to BaseTool in the future, it should be
|
||||
automatically excluded from spec generation without needing to update
|
||||
the ignored list. This test verifies the allowlist approach works
|
||||
by checking that ONLY non-BaseTool fields appear."""
|
||||
with (
|
||||
mock.patch("crewai_tools.generate_tool_specs.dir", return_value=["MockTool"]),
|
||||
mock.patch("crewai_tools.generate_tool_specs.getattr", return_value=MockTool),
|
||||
):
|
||||
extractor.extract_all_tools()
|
||||
tool_info = extractor.tools_spec[0]
|
||||
|
||||
props = set(tool_info["init_params_schema"].get("properties", {}).keys())
|
||||
base_all = set(BaseTool.model_fields) | set(BaseTool.model_computed_fields)
|
||||
|
||||
leaked = base_all & props
|
||||
assert not leaked, (
|
||||
f"BaseTool fields should be auto-excluded but found: {leaked}. "
|
||||
"The spec generator should dynamically compute BaseTool's fields "
|
||||
"instead of using a hardcoded denylist."
|
||||
)
|
||||
|
||||
|
||||
def test_save_to_json(extractor, tmp_path):
|
||||
extractor.tools_spec = [
|
||||
{
|
||||
|
||||
671
lib/crewai-tools/tests/tools/test_nl2sql_security.py
Normal file
671
lib/crewai-tools/tests/tools/test_nl2sql_security.py
Normal file
@@ -0,0 +1,671 @@
|
||||
"""Security tests for NL2SQLTool.
|
||||
|
||||
Uses an in-memory SQLite database so no external service is needed.
|
||||
SQLite does not have information_schema, so we patch the schema-introspection
|
||||
helpers to avoid bootstrap failures and focus purely on the security logic.
|
||||
"""
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Skip the entire module if SQLAlchemy is not installed
|
||||
pytest.importorskip("sqlalchemy")
|
||||
|
||||
from sqlalchemy import create_engine, text # noqa: E402
|
||||
|
||||
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool # noqa: E402
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_URI = "sqlite://" # in-memory
|
||||
|
||||
|
||||
def _make_tool(allow_dml: bool = False, **kwargs) -> NL2SQLTool:
|
||||
"""Return a NL2SQLTool wired to an in-memory SQLite DB.
|
||||
|
||||
Schema-introspection is patched out so we can create the tool without a
|
||||
real PostgreSQL information_schema.
|
||||
"""
|
||||
with (
|
||||
patch.object(NL2SQLTool, "_fetch_available_tables", return_value=[]),
|
||||
patch.object(NL2SQLTool, "_fetch_all_available_columns", return_value=[]),
|
||||
):
|
||||
return NL2SQLTool(db_uri=SQLITE_URI, allow_dml=allow_dml, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read-only enforcement (allow_dml=False)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadOnlyMode:
|
||||
def test_select_allowed_by_default(self):
|
||||
tool = _make_tool()
|
||||
# SQLite supports SELECT without information_schema
|
||||
result = tool.execute_sql("SELECT 1 AS val")
|
||||
assert result == [{"val": 1}]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"stmt",
|
||||
[
|
||||
"INSERT INTO t VALUES (1)",
|
||||
"UPDATE t SET col = 1",
|
||||
"DELETE FROM t",
|
||||
"DROP TABLE t",
|
||||
"ALTER TABLE t ADD col TEXT",
|
||||
"CREATE TABLE t (id INTEGER)",
|
||||
"TRUNCATE TABLE t",
|
||||
"GRANT SELECT ON t TO user1",
|
||||
"REVOKE SELECT ON t FROM user1",
|
||||
"EXEC sp_something",
|
||||
"EXECUTE sp_something",
|
||||
"CALL proc()",
|
||||
],
|
||||
)
|
||||
def test_write_statements_blocked_by_default(self, stmt: str):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(stmt)
|
||||
|
||||
def test_explain_allowed(self):
|
||||
tool = _make_tool()
|
||||
# Should not raise
|
||||
tool._validate_query("EXPLAIN SELECT 1")
|
||||
|
||||
def test_read_only_cte_allowed(self):
|
||||
tool = _make_tool()
|
||||
tool._validate_query("WITH cte AS (SELECT 1) SELECT * FROM cte")
|
||||
|
||||
def test_show_allowed(self):
|
||||
tool = _make_tool()
|
||||
tool._validate_query("SHOW TABLES")
|
||||
|
||||
def test_describe_allowed(self):
|
||||
tool = _make_tool()
|
||||
tool._validate_query("DESCRIBE users")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DML enabled (allow_dml=True)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDMLEnabled:
|
||||
def test_insert_allowed_when_dml_enabled(self):
|
||||
tool = _make_tool(allow_dml=True)
|
||||
# Should not raise
|
||||
tool._validate_query("INSERT INTO t VALUES (1)")
|
||||
|
||||
def test_delete_allowed_when_dml_enabled(self):
|
||||
tool = _make_tool(allow_dml=True)
|
||||
tool._validate_query("DELETE FROM t WHERE id = 1")
|
||||
|
||||
def test_drop_allowed_when_dml_enabled(self):
|
||||
tool = _make_tool(allow_dml=True)
|
||||
tool._validate_query("DROP TABLE t")
|
||||
|
||||
def test_dml_actually_persists(self):
|
||||
"""End-to-end: INSERT commits when allow_dml=True."""
|
||||
# Use a file-based SQLite so we can verify persistence across sessions
|
||||
import tempfile, os
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
db_path = f.name
|
||||
uri = f"sqlite:///{db_path}"
|
||||
try:
|
||||
tool = _make_tool(allow_dml=True)
|
||||
tool.db_uri = uri
|
||||
|
||||
engine = create_engine(uri)
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("CREATE TABLE items (id INTEGER PRIMARY KEY)"))
|
||||
conn.commit()
|
||||
|
||||
tool.execute_sql("INSERT INTO items VALUES (42)")
|
||||
|
||||
with engine.connect() as conn:
|
||||
rows = conn.execute(text("SELECT id FROM items")).fetchall()
|
||||
assert (42,) in rows
|
||||
finally:
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parameterised query — SQL injection prevention
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParameterisedQueries:
|
||||
def test_table_name_is_parameterised(self):
|
||||
"""_fetch_all_available_columns must not interpolate table_name into SQL."""
|
||||
tool = _make_tool()
|
||||
captured_calls = []
|
||||
|
||||
def recording_execute_sql(self_inner, sql_query, params=None):
|
||||
captured_calls.append((sql_query, params))
|
||||
return []
|
||||
|
||||
with patch.object(NL2SQLTool, "execute_sql", recording_execute_sql):
|
||||
tool._fetch_all_available_columns("users'; DROP TABLE users; --")
|
||||
|
||||
assert len(captured_calls) == 1
|
||||
sql, params = captured_calls[0]
|
||||
# The raw SQL must NOT contain the injected string
|
||||
assert "DROP" not in sql
|
||||
# The table name must be passed as a parameter
|
||||
assert params is not None
|
||||
assert params.get("table_name") == "users'; DROP TABLE users; --"
|
||||
# The SQL template must use the :param syntax
|
||||
assert ":table_name" in sql
|
||||
|
||||
def test_injection_string_not_in_sql_template(self):
|
||||
"""The f-string vulnerability is gone — table name never lands in the SQL."""
|
||||
tool = _make_tool()
|
||||
injection = "'; DROP TABLE users; --"
|
||||
captured = {}
|
||||
|
||||
def spy(self_inner, sql_query, params=None):
|
||||
captured["sql"] = sql_query
|
||||
captured["params"] = params
|
||||
return []
|
||||
|
||||
with patch.object(NL2SQLTool, "execute_sql", spy):
|
||||
tool._fetch_all_available_columns(injection)
|
||||
|
||||
assert injection not in captured["sql"]
|
||||
assert captured["params"]["table_name"] == injection
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session.commit() not called for read-only queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoCommitForReadOnly:
|
||||
def test_select_does_not_commit(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.returns_rows = True
|
||||
mock_result.keys.return_value = ["val"]
|
||||
mock_result.fetchall.return_value = [(1,)]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
mock_session_cls = MagicMock(return_value=mock_session)
|
||||
|
||||
with (
|
||||
patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
|
||||
patch(
|
||||
"crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
|
||||
return_value=mock_session_cls,
|
||||
),
|
||||
):
|
||||
tool.execute_sql("SELECT 1")
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
def test_write_with_dml_enabled_does_commit(self):
|
||||
tool = _make_tool(allow_dml=True)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.returns_rows = False
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
mock_session_cls = MagicMock(return_value=mock_session)
|
||||
|
||||
with (
|
||||
patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
|
||||
patch(
|
||||
"crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
|
||||
return_value=mock_session_cls,
|
||||
),
|
||||
):
|
||||
tool.execute_sql("INSERT INTO t VALUES (1)")
|
||||
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment-variable escape hatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnvVarEscapeHatch:
|
||||
def test_env_var_enables_dml(self):
|
||||
with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "true"}):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
assert tool.allow_dml is True
|
||||
|
||||
def test_env_var_case_insensitive(self):
|
||||
with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "TRUE"}):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
assert tool.allow_dml is True
|
||||
|
||||
def test_env_var_absent_keeps_default(self):
|
||||
env = {k: v for k, v in os.environ.items() if k != "CREWAI_NL2SQL_ALLOW_DML"}
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
assert tool.allow_dml is False
|
||||
|
||||
def test_env_var_false_does_not_enable_dml(self):
|
||||
with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "false"}):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
assert tool.allow_dml is False
|
||||
|
||||
def test_dml_write_blocked_without_env_var(self):
|
||||
env = {k: v for k, v in os.environ.items() if k != "CREWAI_NL2SQL_ALLOW_DML"}
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query("DROP TABLE sensitive_data")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run() propagates ValueError from _validate_query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunValidation:
|
||||
def test_run_raises_on_blocked_query(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._run("DELETE FROM users")
|
||||
|
||||
def test_run_returns_results_for_select(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
result = tool._run("SELECT 1 AS n")
|
||||
assert result == [{"n": 1}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-statement / semicolon injection prevention
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSemicolonInjection:
|
||||
def test_multi_statement_blocked_in_read_only_mode(self):
|
||||
"""SELECT 1; DROP TABLE users must be rejected when allow_dml=False."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="multi-statement"):
|
||||
tool._validate_query("SELECT 1; DROP TABLE users")
|
||||
|
||||
def test_multi_statement_blocked_even_with_only_selects(self):
|
||||
"""Two SELECT statements are still rejected in read-only mode."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="multi-statement"):
|
||||
tool._validate_query("SELECT 1; SELECT 2")
|
||||
|
||||
def test_trailing_semicolon_allowed_single_statement(self):
|
||||
"""A single statement with a trailing semicolon should pass."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
# Should not raise — the part after the semicolon is empty
|
||||
tool._validate_query("SELECT 1;")
|
||||
|
||||
def test_multi_statement_allowed_when_dml_enabled(self):
|
||||
"""Multiple statements are permitted when allow_dml=True."""
|
||||
tool = _make_tool(allow_dml=True)
|
||||
# Should not raise
|
||||
tool._validate_query("SELECT 1; INSERT INTO t VALUES (1)")
|
||||
|
||||
def test_multi_statement_write_still_blocked_individually(self):
|
||||
"""Even with allow_dml=False, a single write statement is blocked."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query("DROP TABLE users")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Writable CTEs (WITH … DELETE/INSERT/UPDATE)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWritableCTE:
|
||||
def test_writable_cte_delete_blocked_in_read_only(self):
|
||||
"""WITH d AS (DELETE FROM users RETURNING *) SELECT * FROM d — blocked."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(
|
||||
"WITH deleted AS (DELETE FROM users RETURNING *) SELECT * FROM deleted"
|
||||
)
|
||||
|
||||
def test_writable_cte_insert_blocked_in_read_only(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(
|
||||
"WITH ins AS (INSERT INTO t VALUES (1) RETURNING id) SELECT * FROM ins"
|
||||
)
|
||||
|
||||
def test_writable_cte_update_blocked_in_read_only(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(
|
||||
"WITH upd AS (UPDATE t SET x=1 RETURNING id) SELECT * FROM upd"
|
||||
)
|
||||
|
||||
def test_writable_cte_allowed_when_dml_enabled(self):
|
||||
tool = _make_tool(allow_dml=True)
|
||||
# Should not raise
|
||||
tool._validate_query(
|
||||
"WITH deleted AS (DELETE FROM users RETURNING *) SELECT * FROM deleted"
|
||||
)
|
||||
|
||||
def test_plain_read_only_cte_still_allowed(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
# No write commands in the CTE body — must pass
|
||||
tool._validate_query("WITH cte AS (SELECT id FROM users) SELECT * FROM cte")
|
||||
|
||||
def test_cte_with_comment_column_not_false_positive(self):
|
||||
"""Column named 'comment' should NOT trigger writable CTE detection."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
# 'comment' is a column name, not a SQL command
|
||||
tool._validate_query(
|
||||
"WITH cte AS (SELECT comment FROM posts) SELECT * FROM cte"
|
||||
)
|
||||
|
||||
def test_cte_with_set_column_not_false_positive(self):
|
||||
"""Column named 'set' should NOT trigger writable CTE detection."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
tool._validate_query(
|
||||
"WITH cte AS (SELECT set, reset FROM config) SELECT * FROM cte"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EXPLAIN ANALYZE executes the underlying query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_cte_with_write_main_query_blocked(self):
|
||||
"""WITH cte AS (SELECT 1) DELETE FROM users — main query must be caught."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(
|
||||
"WITH cte AS (SELECT 1) DELETE FROM users"
|
||||
)
|
||||
|
||||
def test_cte_with_write_main_query_allowed_with_dml(self):
|
||||
"""Main query write after CTE should pass when allow_dml=True."""
|
||||
tool = _make_tool(allow_dml=True)
|
||||
tool._validate_query(
|
||||
"WITH cte AS (SELECT id FROM users) INSERT INTO archive SELECT * FROM cte"
|
||||
)
|
||||
|
||||
def test_cte_with_newline_before_paren_blocked(self):
|
||||
"""AS followed by newline then ( should still detect writable CTE."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(
|
||||
"WITH cte AS\n(DELETE FROM users RETURNING *) SELECT * FROM cte"
|
||||
)
|
||||
|
||||
def test_cte_with_tab_before_paren_blocked(self):
|
||||
"""AS followed by tab then ( should still detect writable CTE."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(
|
||||
"WITH cte AS\t(DELETE FROM users RETURNING *) SELECT * FROM cte"
|
||||
)
|
||||
|
||||
|
||||
class TestExplainAnalyze:
|
||||
def test_explain_analyze_delete_blocked_in_read_only(self):
|
||||
"""EXPLAIN ANALYZE DELETE actually runs the delete — block it."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query("EXPLAIN ANALYZE DELETE FROM users")
|
||||
|
||||
def test_explain_analyse_delete_blocked_in_read_only(self):
|
||||
"""British spelling ANALYSE is also caught."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query("EXPLAIN ANALYSE DELETE FROM users")
|
||||
|
||||
def test_explain_analyze_drop_blocked_in_read_only(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query("EXPLAIN ANALYZE DROP TABLE users")
|
||||
|
||||
def test_explain_analyze_select_allowed_in_read_only(self):
|
||||
"""EXPLAIN ANALYZE on a SELECT is safe — must be permitted."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
tool._validate_query("EXPLAIN ANALYZE SELECT * FROM users")
|
||||
|
||||
def test_explain_without_analyze_allowed(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
tool._validate_query("EXPLAIN SELECT * FROM users")
|
||||
|
||||
def test_explain_analyze_delete_allowed_when_dml_enabled(self):
|
||||
tool = _make_tool(allow_dml=True)
|
||||
tool._validate_query("EXPLAIN ANALYZE DELETE FROM users")
|
||||
|
||||
def test_explain_paren_analyze_delete_blocked_in_read_only(self):
|
||||
"""EXPLAIN (ANALYZE) DELETE actually runs the delete — block it."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query("EXPLAIN (ANALYZE) DELETE FROM users")
|
||||
|
||||
def test_explain_paren_analyze_verbose_delete_blocked_in_read_only(self):
|
||||
"""EXPLAIN (ANALYZE, VERBOSE) DELETE actually runs the delete — block it."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query("EXPLAIN (ANALYZE, VERBOSE) DELETE FROM users")
|
||||
|
||||
def test_explain_paren_verbose_select_allowed_in_read_only(self):
|
||||
"""EXPLAIN (VERBOSE) SELECT is safe — no ANALYZE means no execution."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
tool._validate_query("EXPLAIN (VERBOSE) SELECT * FROM users")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-statement commit covers ALL statements (not just the first)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMultiStatementCommit:
|
||||
def test_select_then_insert_triggers_commit(self):
|
||||
"""SELECT 1; INSERT … — commit must happen because INSERT is a write."""
|
||||
tool = _make_tool(allow_dml=True)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.returns_rows = False
|
||||
mock_session.execute.return_value = mock_result
|
||||
mock_session_cls = MagicMock(return_value=mock_session)
|
||||
|
||||
with (
|
||||
patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
|
||||
patch(
|
||||
"crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
|
||||
return_value=mock_session_cls,
|
||||
),
|
||||
):
|
||||
tool.execute_sql("SELECT 1; INSERT INTO t VALUES (1)")
|
||||
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_select_only_multi_statement_does_not_commit(self):
|
||||
"""Two SELECTs must not trigger a commit even when allow_dml=True."""
|
||||
tool = _make_tool(allow_dml=True)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.returns_rows = True
|
||||
mock_result.keys.return_value = ["v"]
|
||||
mock_result.fetchall.return_value = [(1,)]
|
||||
mock_session.execute.return_value = mock_result
|
||||
mock_session_cls = MagicMock(return_value=mock_session)
|
||||
|
||||
with (
|
||||
patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
|
||||
patch(
|
||||
"crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
|
||||
return_value=mock_session_cls,
|
||||
),
|
||||
):
|
||||
tool.execute_sql("SELECT 1; SELECT 2")
|
||||
|
||||
def test_writable_cte_triggers_commit(self):
|
||||
"""WITH d AS (DELETE ...) must trigger commit when allow_dml=True."""
|
||||
tool = _make_tool(allow_dml=True)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.returns_rows = True
|
||||
mock_result.keys.return_value = ["id"]
|
||||
mock_result.fetchall.return_value = [(1,)]
|
||||
mock_session.execute.return_value = mock_result
|
||||
mock_session_cls = MagicMock(return_value=mock_session)
|
||||
|
||||
with (
|
||||
patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
|
||||
patch(
|
||||
"crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
|
||||
return_value=mock_session_cls,
|
||||
),
|
||||
):
|
||||
tool.execute_sql(
|
||||
"WITH d AS (DELETE FROM users RETURNING *) SELECT * FROM d"
|
||||
)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extended _WRITE_COMMANDS coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtendedWriteCommands:
|
||||
@pytest.mark.parametrize(
|
||||
"stmt",
|
||||
[
|
||||
"UPSERT INTO t VALUES (1)",
|
||||
"LOAD DATA INFILE 'f.csv' INTO TABLE t",
|
||||
"COPY t FROM '/tmp/f.csv'",
|
||||
"VACUUM ANALYZE t",
|
||||
"ANALYZE t",
|
||||
"ANALYSE t",
|
||||
"REINDEX TABLE t",
|
||||
"CLUSTER t USING idx",
|
||||
"REFRESH MATERIALIZED VIEW v",
|
||||
"COMMENT ON TABLE t IS 'desc'",
|
||||
"SET search_path = myschema",
|
||||
"RESET search_path",
|
||||
],
|
||||
)
|
||||
def test_extended_write_commands_blocked_by_default(self, stmt: str):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(stmt)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EXPLAIN ANALYZE VERBOSE handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExplainAnalyzeVerbose:
|
||||
def test_explain_analyze_verbose_select_allowed(self):
|
||||
"""EXPLAIN ANALYZE VERBOSE SELECT should be allowed (read-only)."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
tool._validate_query("EXPLAIN ANALYZE VERBOSE SELECT * FROM users")
|
||||
|
||||
def test_explain_analyze_verbose_delete_blocked(self):
|
||||
"""EXPLAIN ANALYZE VERBOSE DELETE should be blocked."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query("EXPLAIN ANALYZE VERBOSE DELETE FROM users")
|
||||
|
||||
def test_explain_verbose_select_allowed(self):
|
||||
"""EXPLAIN VERBOSE SELECT (no ANALYZE) should be allowed."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
tool._validate_query("EXPLAIN VERBOSE SELECT * FROM users")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CTE with string literal parens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCTEStringLiteralParens:
|
||||
def test_cte_string_paren_does_not_bypass(self):
|
||||
"""Parens inside string literals should not confuse the paren walker."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(
|
||||
"WITH cte AS (SELECT '(' FROM t) DELETE FROM users"
|
||||
)
|
||||
|
||||
def test_cte_string_paren_read_only_allowed(self):
|
||||
"""Read-only CTE with string literal parens should be allowed."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
tool._validate_query(
|
||||
"WITH cte AS (SELECT '(' FROM t) SELECT * FROM cte"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EXPLAIN ANALYZE commit logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExplainAnalyzeCommit:
|
||||
def test_explain_analyze_delete_triggers_commit(self):
|
||||
"""EXPLAIN ANALYZE DELETE should trigger commit when allow_dml=True."""
|
||||
tool = _make_tool(allow_dml=True)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.returns_rows = True
|
||||
mock_result.keys.return_value = ["QUERY PLAN"]
|
||||
mock_result.fetchall.return_value = [("Delete on users",)]
|
||||
mock_session.execute.return_value = mock_result
|
||||
mock_session_cls = MagicMock(return_value=mock_session)
|
||||
|
||||
with (
|
||||
patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
|
||||
patch(
|
||||
"crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
|
||||
return_value=mock_session_cls,
|
||||
),
|
||||
):
|
||||
tool.execute_sql("EXPLAIN ANALYZE DELETE FROM users")
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AS( inside string literals must not confuse CTE detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCTEStringLiteralAS:
|
||||
def test_as_paren_inside_string_does_not_bypass(self):
|
||||
"""'AS (' inside a string literal must not be treated as a CTE body."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(
|
||||
"WITH cte AS (SELECT 'AS (' FROM t) DELETE FROM users"
|
||||
)
|
||||
|
||||
def test_as_paren_inside_string_read_only_ok(self):
|
||||
"""Read-only CTE with 'AS (' in a string should be allowed."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
tool._validate_query(
|
||||
"WITH cte AS (SELECT 'AS (' FROM t) SELECT * FROM cte"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unknown command after CTE should be blocked
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCTEUnknownCommand:
|
||||
def test_unknown_command_after_cte_blocked(self):
|
||||
"""WITH cte AS (SELECT 1) FOOBAR should be blocked as unknown."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="unrecognised"):
|
||||
tool._validate_query("WITH cte AS (SELECT 1) FOOBAR")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,7 @@ requires-python = ">=3.10, <3.14"
|
||||
dependencies = [
|
||||
# Core Dependencies
|
||||
"pydantic~=2.11.9",
|
||||
"openai>=1.83.0,<3",
|
||||
"openai>=2.0.0,<3",
|
||||
"instructor>=1.3.3",
|
||||
# Text Processing
|
||||
"pdfplumber~=0.11.4",
|
||||
@@ -40,7 +40,7 @@ dependencies = [
|
||||
"pydantic-settings~=2.10.1",
|
||||
"httpx~=0.28.1",
|
||||
"mcp~=1.26.0",
|
||||
"uv~=0.9.13",
|
||||
"uv~=0.11.6",
|
||||
"aiosqlite~=0.21.0",
|
||||
"pyyaml~=6.0",
|
||||
"aiofiles~=24.1.0",
|
||||
@@ -55,7 +55,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.14.0",
|
||||
"crewai-tools==1.14.2rc1",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
@@ -68,14 +68,14 @@ openpyxl = [
|
||||
]
|
||||
mem0 = ["mem0ai~=0.1.94"]
|
||||
docling = [
|
||||
"docling~=2.75.0",
|
||||
"docling~=2.84.0",
|
||||
]
|
||||
qdrant = [
|
||||
"qdrant-client[fastembed]~=1.14.3",
|
||||
]
|
||||
aws = [
|
||||
"boto3~=1.40.38",
|
||||
"aiobotocore~=2.25.2",
|
||||
"boto3~=1.42.79",
|
||||
"aiobotocore~=3.4.0",
|
||||
]
|
||||
watson = [
|
||||
"ibm-watsonx-ai~=1.3.39",
|
||||
@@ -87,7 +87,7 @@ litellm = [
|
||||
"litellm~=1.83.0",
|
||||
]
|
||||
bedrock = [
|
||||
"boto3~=1.40.45",
|
||||
"boto3~=1.42.79",
|
||||
]
|
||||
google-genai = [
|
||||
"google-genai~=1.65.0",
|
||||
|
||||
@@ -46,7 +46,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.14.0"
|
||||
__version__ = "1.14.2rc1"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -98,7 +98,6 @@ class A2AErrorCode(IntEnum):
|
||||
"""The specified artifact was not found."""
|
||||
|
||||
|
||||
# Error code to default message mapping
|
||||
ERROR_MESSAGES: dict[int, str] = {
|
||||
A2AErrorCode.JSON_PARSE_ERROR: "Parse error",
|
||||
A2AErrorCode.INVALID_REQUEST: "Invalid Request",
|
||||
|
||||
@@ -63,25 +63,21 @@ class A2AExtension(Protocol):
|
||||
Example:
|
||||
class MyExtension:
|
||||
def inject_tools(self, agent: Agent) -> None:
|
||||
# Add custom tools to the agent
|
||||
pass
|
||||
|
||||
def extract_state_from_history(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> ConversationState | None:
|
||||
# Extract state from conversation
|
||||
return None
|
||||
|
||||
def augment_prompt(
|
||||
self, base_prompt: str, conversation_state: ConversationState | None
|
||||
) -> str:
|
||||
# Add custom instructions
|
||||
return base_prompt
|
||||
|
||||
def process_response(
|
||||
self, agent_response: Any, conversation_state: ConversationState | None
|
||||
) -> Any:
|
||||
# Modify response if needed
|
||||
return agent_response
|
||||
"""
|
||||
|
||||
|
||||
@@ -77,7 +77,6 @@ def extract_a2a_agent_ids_from_config(
|
||||
else:
|
||||
configs = a2a_config
|
||||
|
||||
# Filter to only client configs (those with endpoint)
|
||||
client_configs: list[A2AClientConfigTypes] = [
|
||||
config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig))
|
||||
]
|
||||
|
||||
@@ -84,6 +84,7 @@ from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
from crewai.skills.loader import activate_skill, discover_skills
|
||||
from crewai.skills.models import INSTRUCTIONS, Skill as SkillModel
|
||||
from crewai.state.checkpoint_config import CheckpointConfig, apply_checkpoint
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities.agent_utils import (
|
||||
@@ -98,6 +99,7 @@ from crewai.utilities.converter import Converter, ConverterError
|
||||
from crewai.utilities.env import get_env_context
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.guardrail_types import GuardrailCallable, GuardrailType
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.prompts import Prompts, StandardPromptResult, SystemPromptResult
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
@@ -499,8 +501,8 @@ class Agent(BaseAgent):
|
||||
self.tools_handler.last_used_tool = None
|
||||
|
||||
task_prompt = task.prompt()
|
||||
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
|
||||
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
|
||||
task_prompt = build_task_prompt_with_schema(task, task_prompt)
|
||||
task_prompt = format_task_with_context(task_prompt, context)
|
||||
return self._retrieve_memory_context(task, task_prompt)
|
||||
|
||||
def _finalize_task_prompt(
|
||||
@@ -562,7 +564,7 @@ class Agent(BaseAgent):
|
||||
m.format() for m in matches
|
||||
)
|
||||
if memory.strip() != "":
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
task_prompt += I18N_DEFAULT.slice("memory").format(memory=memory)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -968,14 +970,13 @@ class Agent(BaseAgent):
|
||||
agent=self,
|
||||
has_tools=len(raw_tools) > 0,
|
||||
use_native_tool_calling=use_native_tool_calling,
|
||||
i18n=self.i18n,
|
||||
use_system_prompt=self.use_system_prompt,
|
||||
system_template=self.system_template,
|
||||
prompt_template=self.prompt_template,
|
||||
response_template=self.response_template,
|
||||
).task_execution()
|
||||
|
||||
stop_words = [self.i18n.slice("observation")]
|
||||
stop_words = [I18N_DEFAULT.slice("observation")]
|
||||
if self.response_template:
|
||||
stop_words.append(
|
||||
self.response_template.split("{{ .Response }}")[1].strip()
|
||||
@@ -1017,7 +1018,6 @@ class Agent(BaseAgent):
|
||||
self.agent_executor = self.executor_class(
|
||||
llm=self.llm,
|
||||
task=task,
|
||||
i18n=self.i18n,
|
||||
agent=self,
|
||||
crew=self.crew,
|
||||
tools=parsed_tools,
|
||||
@@ -1262,10 +1262,10 @@ class Agent(BaseAgent):
|
||||
from_agent=self,
|
||||
),
|
||||
)
|
||||
query = self.i18n.slice("knowledge_search_query").format(
|
||||
query = I18N_DEFAULT.slice("knowledge_search_query").format(
|
||||
task_prompt=task_prompt
|
||||
)
|
||||
rewriter_prompt = self.i18n.slice("knowledge_search_query_system_prompt")
|
||||
rewriter_prompt = I18N_DEFAULT.slice("knowledge_search_query_system_prompt")
|
||||
if not isinstance(self.llm, BaseLLM):
|
||||
self._logger.log(
|
||||
"warning",
|
||||
@@ -1342,7 +1342,6 @@ class Agent(BaseAgent):
|
||||
|
||||
raw_tools: list[BaseTool] = self.tools or []
|
||||
|
||||
# Inject memory tools for standalone kickoff (crew path handles its own)
|
||||
agent_memory = getattr(self, "memory", None)
|
||||
if agent_memory is not None:
|
||||
from crewai.tools.memory_tools import create_memory_tools
|
||||
@@ -1384,7 +1383,6 @@ class Agent(BaseAgent):
|
||||
request_within_rpm_limit=rpm_limit_fn,
|
||||
callbacks=[TokenCalcHandler(self._token_process)],
|
||||
response_model=response_format,
|
||||
i18n=self.i18n,
|
||||
)
|
||||
|
||||
all_files: dict[str, Any] = {}
|
||||
@@ -1401,7 +1399,6 @@ class Agent(BaseAgent):
|
||||
if input_files:
|
||||
all_files.update(input_files)
|
||||
|
||||
# Inject memory context for standalone kickoff (recall before execution)
|
||||
if agent_memory is not None:
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
@@ -1420,7 +1417,7 @@ class Agent(BaseAgent):
|
||||
m.format() for m in matches
|
||||
)
|
||||
if memory_block:
|
||||
formatted_messages += "\n\n" + self.i18n.slice("memory").format(
|
||||
formatted_messages += "\n\n" + I18N_DEFAULT.slice("memory").format(
|
||||
memory=memory_block
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
@@ -1461,6 +1458,7 @@ class Agent(BaseAgent):
|
||||
messages: str | list[LLMMessage],
|
||||
response_format: type[Any] | None = None,
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
) -> LiteAgentOutput | Coroutine[Any, Any, LiteAgentOutput]:
|
||||
"""Execute the agent with the given messages using the AgentExecutor.
|
||||
|
||||
@@ -1479,6 +1477,9 @@ class Agent(BaseAgent):
|
||||
response_format: Optional Pydantic model for structured output.
|
||||
input_files: Optional dict of named files to attach to the message.
|
||||
Files can be paths, bytes, or File objects from crewai_files.
|
||||
from_checkpoint: Optional checkpoint config. If ``restore_from``
|
||||
is set, the agent resumes from that checkpoint. Remaining
|
||||
config fields enable checkpointing for the run.
|
||||
|
||||
Returns:
|
||||
LiteAgentOutput: The result of the agent execution.
|
||||
@@ -1487,8 +1488,14 @@ class Agent(BaseAgent):
|
||||
Note:
|
||||
For explicit async usage outside of Flow, use kickoff_async() directly.
|
||||
"""
|
||||
# Magic auto-async: if inside event loop (e.g., inside a Flow),
|
||||
# return coroutine for Flow to await
|
||||
restored = apply_checkpoint(self, from_checkpoint)
|
||||
if restored is not None:
|
||||
return restored.kickoff( # type: ignore[no-any-return]
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
input_files=input_files,
|
||||
)
|
||||
|
||||
if is_inside_event_loop():
|
||||
return self.kickoff_async(messages, response_format, input_files)
|
||||
|
||||
@@ -1624,7 +1631,7 @@ class Agent(BaseAgent):
|
||||
try:
|
||||
model_schema = generate_model_description(response_format)
|
||||
schema = json.dumps(model_schema, indent=2)
|
||||
instructions = self.i18n.slice("formatted_task_instructions").format(
|
||||
instructions = I18N_DEFAULT.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
|
||||
@@ -1639,7 +1646,7 @@ class Agent(BaseAgent):
|
||||
if isinstance(conversion_result, BaseModel):
|
||||
formatted_result = conversion_result
|
||||
except ConverterError:
|
||||
pass # Keep raw output if conversion fails
|
||||
pass
|
||||
else:
|
||||
raw_output = str(output) if not isinstance(output, str) else output
|
||||
|
||||
@@ -1721,7 +1728,6 @@ class Agent(BaseAgent):
|
||||
elif callable(self.guardrail):
|
||||
guardrail_callable = self.guardrail
|
||||
else:
|
||||
# Should not happen if called from kickoff with guardrail check
|
||||
return output
|
||||
|
||||
guardrail_result = process_guardrail(
|
||||
@@ -1767,6 +1773,7 @@ class Agent(BaseAgent):
|
||||
messages: str | list[LLMMessage],
|
||||
response_format: type[Any] | None = None,
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""Execute the agent asynchronously with the given messages.
|
||||
|
||||
@@ -1782,10 +1789,20 @@ class Agent(BaseAgent):
|
||||
response_format: Optional Pydantic model for structured output.
|
||||
input_files: Optional dict of named files to attach to the message.
|
||||
Files can be paths, bytes, or File objects from crewai_files.
|
||||
from_checkpoint: Optional checkpoint config. If ``restore_from``
|
||||
is set, the agent resumes from that checkpoint.
|
||||
|
||||
Returns:
|
||||
LiteAgentOutput: The result of the agent execution.
|
||||
"""
|
||||
restored = apply_checkpoint(self, from_checkpoint)
|
||||
if restored is not None:
|
||||
return await restored.kickoff_async( # type: ignore[no-any-return]
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
input_files=input_files,
|
||||
)
|
||||
|
||||
executor, inputs, agent_info, parsed_tools = self._prepare_kickoff(
|
||||
messages, response_format, input_files
|
||||
)
|
||||
@@ -1815,6 +1832,7 @@ class Agent(BaseAgent):
|
||||
messages: str | list[LLMMessage],
|
||||
response_format: type[Any] | None = None,
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""Async version of kickoff. Alias for kickoff_async.
|
||||
|
||||
@@ -1822,8 +1840,12 @@ class Agent(BaseAgent):
|
||||
messages: Either a string query or a list of message dictionaries.
|
||||
response_format: Optional Pydantic model for structured output.
|
||||
input_files: Optional dict of named files to attach to the message.
|
||||
from_checkpoint: Optional checkpoint config. If ``restore_from``
|
||||
is set, the agent resumes from that checkpoint.
|
||||
|
||||
Returns:
|
||||
LiteAgentOutput: The result of the agent execution.
|
||||
"""
|
||||
return await self.kickoff_async(messages, response_format, input_files)
|
||||
return await self.kickoff_async(
|
||||
messages, response_format, input_files, from_checkpoint
|
||||
)
|
||||
|
||||
@@ -41,7 +41,6 @@ class PlanningConfig(BaseModel):
|
||||
from crewai import Agent
|
||||
from crewai.agent.planning_config import PlanningConfig
|
||||
|
||||
# Simple usage — fast, linear execution (default)
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research topics",
|
||||
@@ -49,7 +48,6 @@ class PlanningConfig(BaseModel):
|
||||
planning_config=PlanningConfig(),
|
||||
)
|
||||
|
||||
# Balanced — replan only when steps fail
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research topics",
|
||||
@@ -59,7 +57,6 @@ class PlanningConfig(BaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
# Full adaptive planning with refinement and replanning
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research topics",
|
||||
@@ -69,7 +66,7 @@ class PlanningConfig(BaseModel):
|
||||
max_attempts=3,
|
||||
max_steps=10,
|
||||
plan_prompt="Create a focused plan for: {description}",
|
||||
llm="gpt-4o-mini", # Use cheaper model for planning
|
||||
llm="gpt-4o-mini",
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
@@ -24,7 +24,6 @@ if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.i18n import I18N
|
||||
|
||||
|
||||
def handle_reasoning(agent: Agent, task: Task) -> None:
|
||||
@@ -40,7 +39,6 @@ def handle_reasoning(agent: Agent, task: Task) -> None:
|
||||
agent: The agent performing the task.
|
||||
task: The task to execute.
|
||||
"""
|
||||
# Check if planning is enabled using the planning_enabled property
|
||||
if not getattr(agent, "planning_enabled", False):
|
||||
return
|
||||
|
||||
@@ -59,46 +57,50 @@ def handle_reasoning(agent: Agent, task: Task) -> None:
|
||||
agent._logger.log("error", f"Error during planning: {e!s}")
|
||||
|
||||
|
||||
def build_task_prompt_with_schema(task: Task, task_prompt: str, i18n: I18N) -> str:
|
||||
def build_task_prompt_with_schema(task: Task, task_prompt: str) -> str:
|
||||
"""Build task prompt with JSON/Pydantic schema instructions if applicable.
|
||||
|
||||
Args:
|
||||
task: The task being executed.
|
||||
task_prompt: The initial task prompt.
|
||||
i18n: Internationalization instance.
|
||||
|
||||
Returns:
|
||||
The task prompt potentially augmented with schema instructions.
|
||||
"""
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
if (task.output_json or task.output_pydantic) and not task.response_model:
|
||||
if task.output_json:
|
||||
schema_dict = generate_model_description(task.output_json)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
task_prompt += "\n" + I18N_DEFAULT.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
elif task.output_pydantic:
|
||||
schema_dict = generate_model_description(task.output_pydantic)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
task_prompt += "\n" + I18N_DEFAULT.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
return task_prompt
|
||||
|
||||
|
||||
def format_task_with_context(task_prompt: str, context: str | None, i18n: I18N) -> str:
|
||||
def format_task_with_context(task_prompt: str, context: str | None) -> str:
|
||||
"""Format task prompt with context if provided.
|
||||
|
||||
Args:
|
||||
task_prompt: The task prompt.
|
||||
context: Optional context string.
|
||||
i18n: Internationalization instance.
|
||||
|
||||
Returns:
|
||||
The task prompt formatted with context if provided.
|
||||
"""
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
if context:
|
||||
return i18n.slice("task_with_context").format(task=task_prompt, context=context)
|
||||
return I18N_DEFAULT.slice("task_with_context").format(
|
||||
task=task_prompt, context=context
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.import_utils import require
|
||||
|
||||
|
||||
@@ -186,7 +187,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
task_prompt = task.prompt() if hasattr(task, "prompt") else str(task)
|
||||
|
||||
if context:
|
||||
task_prompt = self.i18n.slice("task_with_context").format(
|
||||
task_prompt = I18N_DEFAULT.slice("task_with_context").format(
|
||||
task=task_prompt, context=context
|
||||
)
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from crewai.events.types.agent_events import (
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.import_utils import require
|
||||
|
||||
|
||||
@@ -133,7 +134,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
try:
|
||||
task_prompt: str = task.prompt()
|
||||
if context:
|
||||
task_prompt = self.i18n.slice("task_with_context").format(
|
||||
task_prompt = I18N_DEFAULT.slice("task_with_context").format(
|
||||
task=task_prompt, context=context
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
|
||||
@@ -99,12 +99,10 @@ class OpenAIAgentToolAdapter(BaseToolAdapter):
|
||||
Returns:
|
||||
Tool execution result.
|
||||
"""
|
||||
# Get the parameter name from the schema
|
||||
param_name: str = next(
|
||||
iter(tool.args_schema.model_json_schema()["properties"].keys())
|
||||
)
|
||||
|
||||
# Handle different argument types
|
||||
args_dict: dict[str, Any]
|
||||
if isinstance(arguments, dict):
|
||||
args_dict = arguments
|
||||
@@ -116,16 +114,13 @@ class OpenAIAgentToolAdapter(BaseToolAdapter):
|
||||
else:
|
||||
args_dict = {param_name: str(arguments)}
|
||||
|
||||
# Run the tool with the processed arguments
|
||||
output: Any | Awaitable[Any] = tool._run(**args_dict)
|
||||
|
||||
# Await if the tool returned a coroutine
|
||||
if inspect.isawaitable(output):
|
||||
result: Any = await output
|
||||
else:
|
||||
result = output
|
||||
|
||||
# Ensure the result is JSON serializable
|
||||
if isinstance(result, (dict, list, str, int, float, bool, type(None))):
|
||||
return result
|
||||
return str(result)
|
||||
|
||||
@@ -8,7 +8,7 @@ import json
|
||||
from typing import Any
|
||||
|
||||
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
|
||||
class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
@@ -59,10 +59,8 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
if not self._output_format:
|
||||
return base_prompt
|
||||
|
||||
output_schema: str = (
|
||||
get_i18n()
|
||||
.slice("formatted_task_instructions")
|
||||
.format(output_format=json.dumps(self._schema, indent=2))
|
||||
output_schema: str = I18N_DEFAULT.slice("formatted_task_instructions").format(
|
||||
output_format=json.dumps(self._schema, indent=2)
|
||||
)
|
||||
|
||||
return f"{base_prompt}\n\n{output_schema}"
|
||||
|
||||
@@ -43,7 +43,6 @@ from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.string_utils import interpolate_only
|
||||
@@ -52,7 +51,6 @@ from crewai.utilities.string_utils import interpolate_only
|
||||
if TYPE_CHECKING:
|
||||
from crewai.context import ExecutionContext
|
||||
from crewai.crew import Crew
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
|
||||
|
||||
def _validate_crew_ref(value: Any) -> Any:
|
||||
@@ -179,7 +177,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
agent_executor: An instance of the CrewAgentExecutor class.
|
||||
llm (Any): Language model that will run the agent.
|
||||
crew (Any): Crew to which the agent belongs.
|
||||
i18n (I18N): Internationalization settings.
|
||||
|
||||
cache_handler ([CacheHandler]): An instance of the CacheHandler class.
|
||||
tools_handler ([ToolsHandler]): An instance of the ToolsHandler class.
|
||||
max_tokens: Maximum number of tokens for the agent to generate in a response.
|
||||
@@ -269,9 +267,6 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
_serialize_crew_ref, return_type=str | None, when_used="always"
|
||||
),
|
||||
] = Field(default=None, description="Crew to which the agent belongs.")
|
||||
i18n: I18N = Field(
|
||||
default_factory=get_i18n, description="Internationalization settings."
|
||||
)
|
||||
cache_handler: CacheHandler | None = Field(
|
||||
default=None, description="An instance of the CacheHandler class."
|
||||
)
|
||||
@@ -342,19 +337,16 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
execution_context: ExecutionContext | None = Field(default=None)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls, path: str, *, provider: BaseProvider | None = None
|
||||
) -> Self:
|
||||
"""Restore an Agent from a checkpoint file."""
|
||||
def from_checkpoint(cls, config: CheckpointConfig) -> Self:
|
||||
"""Restore an Agent from a checkpoint.
|
||||
|
||||
Args:
|
||||
config: Checkpoint configuration with ``restore_from`` set.
|
||||
"""
|
||||
from crewai.context import apply_execution_context
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
from crewai.state.runtime import RuntimeState
|
||||
|
||||
state = RuntimeState.from_checkpoint(
|
||||
path,
|
||||
provider=provider or JsonProvider(),
|
||||
context={"from_checkpoint": True},
|
||||
)
|
||||
state = RuntimeState.from_checkpoint(config, context={"from_checkpoint": True})
|
||||
for entity in state.root:
|
||||
if isinstance(entity, cls):
|
||||
if entity.execution_context is not None:
|
||||
@@ -363,7 +355,9 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
entity.agent_executor.agent = entity
|
||||
entity.agent_executor._resuming = True
|
||||
return entity
|
||||
raise ValueError(f"No {cls.__name__} found in checkpoint: {path}")
|
||||
raise ValueError(
|
||||
f"No {cls.__name__} found in checkpoint: {config.restore_from}"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -389,7 +383,6 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
if isinstance(tool, BaseTool):
|
||||
processed_tools.append(tool)
|
||||
elif all(hasattr(tool, attr) for attr in required_attrs):
|
||||
# Tool has the required attributes, create a Tool instance
|
||||
processed_tools.append(Tool.from_langchain(tool))
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -454,14 +447,12 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_and_set_attributes(self) -> Self:
|
||||
# Validate required fields
|
||||
for field in ["role", "goal", "backstory"]:
|
||||
if getattr(self, field) is None:
|
||||
raise ValueError(
|
||||
f"{field} must be provided either directly or through config"
|
||||
)
|
||||
|
||||
# Set private attributes
|
||||
self._logger = Logger(verbose=self.verbose)
|
||||
if self.max_rpm and not self._rpm_controller:
|
||||
self._rpm_controller = RPMController(
|
||||
@@ -470,7 +461,6 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
if not self._token_process:
|
||||
self._token_process = TokenProcess()
|
||||
|
||||
# Initialize security_config if not provided
|
||||
if self.security_config is None:
|
||||
self.security_config = SecurityConfig()
|
||||
|
||||
@@ -572,14 +562,11 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
"actions",
|
||||
}
|
||||
|
||||
# Copy llm
|
||||
existing_llm = shallow_copy(self.llm)
|
||||
copied_knowledge = shallow_copy(self.knowledge)
|
||||
copied_knowledge_storage = shallow_copy(self.knowledge_storage)
|
||||
# Properly copy knowledge sources if they exist
|
||||
existing_knowledge_sources = None
|
||||
if self.knowledge_sources:
|
||||
# Create a shared storage instance for all knowledge sources
|
||||
shared_storage = (
|
||||
self.knowledge_sources[0].storage if self.knowledge_sources else None
|
||||
)
|
||||
@@ -591,7 +578,6 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
if hasattr(source, "model_copy")
|
||||
else shallow_copy(source)
|
||||
)
|
||||
# Ensure all copied sources use the same storage instance
|
||||
copied_source.storage = shared_storage
|
||||
existing_knowledge_sources.append(copied_source)
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ if TYPE_CHECKING:
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.i18n import I18N
|
||||
|
||||
|
||||
class BaseAgentExecutor(BaseModel):
|
||||
@@ -28,7 +27,6 @@ class BaseAgentExecutor(BaseModel):
|
||||
max_iter: int = Field(default=25)
|
||||
messages: list[LLMMessage] = Field(default_factory=list)
|
||||
_resuming: bool = PrivateAttr(default=False)
|
||||
_i18n: I18N | None = PrivateAttr(default=None)
|
||||
|
||||
def _save_to_memory(self, output: AgentFinish) -> None:
|
||||
"""Save task result to unified memory (memory or crew._memory)."""
|
||||
|
||||
@@ -4,8 +4,6 @@ import re
|
||||
from typing import Final
|
||||
|
||||
|
||||
# crewai.agents.parser constants
|
||||
|
||||
FINAL_ANSWER_ACTION: Final[str] = "Final Answer:"
|
||||
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE: Final[str] = (
|
||||
"I did it wrong. Invalid Format: I missed the 'Action:' after 'Thought:'. I will do right next, and don't use a tool I have already used.\n"
|
||||
|
||||
@@ -67,7 +67,7 @@ from crewai.utilities.agent_utils import (
|
||||
)
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.file_store import aget_all_files, get_all_files
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.printer import PRINTER
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
@@ -135,9 +135,8 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
def __init__(self, i18n: I18N | None = None, **kwargs: Any) -> None:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._i18n = i18n or get_i18n()
|
||||
if not self.before_llm_call_hooks:
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
if not self.after_llm_call_hooks:
|
||||
@@ -297,7 +296,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
# Check if model supports native function calling
|
||||
use_native_tools = (
|
||||
hasattr(self.llm, "supports_function_calling")
|
||||
and callable(getattr(self.llm, "supports_function_calling", None))
|
||||
@@ -308,7 +306,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
if use_native_tools:
|
||||
return self._invoke_loop_native_tools()
|
||||
|
||||
# Fall back to ReAct text-based pattern
|
||||
return self._invoke_loop_react()
|
||||
|
||||
def _invoke_loop_react(self) -> AgentFinish:
|
||||
@@ -328,7 +325,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
formatted_answer = handle_max_iterations_exceeded(
|
||||
formatted_answer,
|
||||
printer=PRINTER,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
@@ -349,7 +345,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
executor_context=self,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
# breakpoint()
|
||||
if self.response_model is not None:
|
||||
try:
|
||||
if isinstance(answer, BaseModel):
|
||||
@@ -367,7 +362,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
text=answer,
|
||||
)
|
||||
except ValidationError:
|
||||
# If validation fails, convert BaseModel to JSON string for parsing
|
||||
answer_str = (
|
||||
answer.model_dump_json()
|
||||
if isinstance(answer, BaseModel)
|
||||
@@ -377,14 +371,12 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
answer_str, self.use_stop_words
|
||||
) # type: ignore[assignment]
|
||||
else:
|
||||
# When no response_model, answer should be a string
|
||||
answer_str = str(answer) if not isinstance(answer, str) else answer
|
||||
formatted_answer = process_llm_response(
|
||||
answer_str, self.use_stop_words
|
||||
) # type: ignore[assignment]
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
# Extract agent fingerprint if available
|
||||
fingerprint_context = {}
|
||||
if (
|
||||
self.agent
|
||||
@@ -401,7 +393,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
agent_action=formatted_answer,
|
||||
fingerprint_context=fingerprint_context,
|
||||
tools=self.tools,
|
||||
i18n=self._i18n,
|
||||
agent_key=self.agent.key if self.agent else None,
|
||||
agent_role=self.agent.role if self.agent else None,
|
||||
tools_handler=self.tools_handler,
|
||||
@@ -429,7 +420,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
# Do not retry on litellm errors
|
||||
raise e
|
||||
if is_context_length_exceeded(e):
|
||||
handle_context_length(
|
||||
@@ -438,7 +428,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
messages=self.messages,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
continue
|
||||
@@ -447,10 +436,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
finally:
|
||||
self.iterations += 1
|
||||
|
||||
# During the invoke loop, formatted_answer alternates between AgentAction
|
||||
# (when the agent is using tools) and eventually becomes AgentFinish
|
||||
# (when the agent reaches a final answer). This check confirms we've
|
||||
# reached a final answer and helps type checking understand this transition.
|
||||
if not isinstance(formatted_answer, AgentFinish):
|
||||
raise RuntimeError(
|
||||
"Agent execution ended without reaching a final answer. "
|
||||
@@ -469,9 +454,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
# Convert tools to OpenAI schema format
|
||||
if not self.original_tools:
|
||||
# No tools available, fall back to simple LLM call
|
||||
return self._invoke_loop_native_no_tools()
|
||||
|
||||
openai_tools, available_functions, self._tool_name_mapping = (
|
||||
@@ -484,7 +467,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
formatted_answer = handle_max_iterations_exceeded(
|
||||
None,
|
||||
printer=PRINTER,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
@@ -495,10 +477,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
# Call LLM with native tools
|
||||
# Pass available_functions=None so the LLM returns tool_calls
|
||||
# without executing them. The executor handles tool execution
|
||||
# via _handle_native_tool_calls to properly manage message history.
|
||||
answer = get_llm_response(
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
messages=self.messages,
|
||||
@@ -513,32 +491,26 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
|
||||
# Check if the response is a list of tool calls
|
||||
if (
|
||||
isinstance(answer, list)
|
||||
and answer
|
||||
and self._is_tool_call_list(answer)
|
||||
):
|
||||
# Handle tool calls - execute tools and add results to messages
|
||||
tool_finish = self._handle_native_tool_calls(
|
||||
answer, available_functions
|
||||
)
|
||||
# If tool has result_as_answer=True, return immediately
|
||||
if tool_finish is not None:
|
||||
return tool_finish
|
||||
# Continue loop to let LLM analyze results and decide next steps
|
||||
continue
|
||||
|
||||
# Text or other response - handle as potential final answer
|
||||
if isinstance(answer, str):
|
||||
# Text response - this is the final answer
|
||||
formatted_answer = AgentFinish(
|
||||
thought="",
|
||||
output=answer,
|
||||
text=answer,
|
||||
)
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(answer) # Save final answer to messages
|
||||
self._append_message(answer)
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
@@ -554,14 +526,13 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
# Unexpected response type, treat as final answer
|
||||
formatted_answer = AgentFinish(
|
||||
thought="",
|
||||
output=str(answer),
|
||||
text=str(answer),
|
||||
)
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(str(answer)) # Save final answer to messages
|
||||
self._append_message(str(answer))
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
@@ -575,7 +546,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
messages=self.messages,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
continue
|
||||
@@ -633,12 +603,10 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
if not response:
|
||||
return False
|
||||
first_item = response[0]
|
||||
# OpenAI-style
|
||||
if hasattr(first_item, "function") or (
|
||||
isinstance(first_item, dict) and "function" in first_item
|
||||
):
|
||||
return True
|
||||
# Anthropic-style (object with attributes)
|
||||
if (
|
||||
hasattr(first_item, "type")
|
||||
and getattr(first_item, "type", None) == "tool_use"
|
||||
@@ -646,14 +614,12 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
return True
|
||||
if hasattr(first_item, "name") and hasattr(first_item, "input"):
|
||||
return True
|
||||
# Bedrock-style (dict with name and input keys)
|
||||
if (
|
||||
isinstance(first_item, dict)
|
||||
and "name" in first_item
|
||||
and "input" in first_item
|
||||
):
|
||||
return True
|
||||
# Gemini-style
|
||||
if hasattr(first_item, "function_call") and first_item.function_call:
|
||||
return True
|
||||
return False
|
||||
@@ -712,8 +678,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
for _, func_name, _ in parsed_calls
|
||||
)
|
||||
|
||||
# Preserve historical sequential behavior for result_as_answer batches.
|
||||
# Also avoid threading around usage counters for max_usage_count tools.
|
||||
if has_result_as_answer_in_batch or has_max_usage_count_in_batch:
|
||||
logger.debug(
|
||||
"Skipping parallel native execution because batch includes result_as_answer or max_usage_count tool"
|
||||
@@ -771,7 +735,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
if tool_finish:
|
||||
return tool_finish
|
||||
|
||||
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
|
||||
reasoning_prompt = I18N_DEFAULT.slice("post_tool_reasoning")
|
||||
reasoning_message: LLMMessage = {
|
||||
"role": "user",
|
||||
"content": reasoning_prompt,
|
||||
@@ -779,7 +743,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
self.messages.append(reasoning_message)
|
||||
return None
|
||||
|
||||
# Sequential behavior: process only first tool call, then force reflection.
|
||||
call_id, func_name, func_args = parsed_calls[0]
|
||||
self._append_assistant_tool_calls_message([(call_id, func_name, func_args)])
|
||||
|
||||
@@ -795,7 +758,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
if tool_finish:
|
||||
return tool_finish
|
||||
|
||||
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
|
||||
reasoning_prompt = I18N_DEFAULT.slice("post_tool_reasoning")
|
||||
reasoning_message = {
|
||||
"role": "user",
|
||||
"content": reasoning_prompt,
|
||||
@@ -833,7 +796,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
func_name = sanitize_tool_name(
|
||||
func_info.get("name", "") or tool_call.get("name", "")
|
||||
)
|
||||
func_args = func_info.get("arguments", "{}") or tool_call.get("input", {})
|
||||
func_args = func_info.get("arguments") or tool_call.get("input", {})
|
||||
return call_id, func_name, func_args
|
||||
return None
|
||||
|
||||
@@ -1170,7 +1133,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
formatted_answer = handle_max_iterations_exceeded(
|
||||
formatted_answer,
|
||||
printer=PRINTER,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
@@ -1209,7 +1171,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
text=answer,
|
||||
)
|
||||
except ValidationError:
|
||||
# If validation fails, convert BaseModel to JSON string for parsing
|
||||
answer_str = (
|
||||
answer.model_dump_json()
|
||||
if isinstance(answer, BaseModel)
|
||||
@@ -1219,7 +1180,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
answer_str, self.use_stop_words
|
||||
) # type: ignore[assignment]
|
||||
else:
|
||||
# When no response_model, answer should be a string
|
||||
answer_str = str(answer) if not isinstance(answer, str) else answer
|
||||
formatted_answer = process_llm_response(
|
||||
answer_str, self.use_stop_words
|
||||
@@ -1242,7 +1202,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
agent_action=formatted_answer,
|
||||
fingerprint_context=fingerprint_context,
|
||||
tools=self.tools,
|
||||
i18n=self._i18n,
|
||||
agent_key=self.agent.key if self.agent else None,
|
||||
agent_role=self.agent.role if self.agent else None,
|
||||
tools_handler=self.tools_handler,
|
||||
@@ -1278,7 +1237,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
messages=self.messages,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
continue
|
||||
@@ -1318,7 +1276,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
formatted_answer = handle_max_iterations_exceeded(
|
||||
None,
|
||||
printer=PRINTER,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
@@ -1329,10 +1286,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
# Call LLM with native tools
|
||||
# Pass available_functions=None so the LLM returns tool_calls
|
||||
# without executing them. The executor handles tool execution
|
||||
# via _handle_native_tool_calls to properly manage message history.
|
||||
answer = await aget_llm_response(
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
messages=self.messages,
|
||||
@@ -1346,32 +1299,26 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
executor_context=self,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
# Check if the response is a list of tool calls
|
||||
if (
|
||||
isinstance(answer, list)
|
||||
and answer
|
||||
and self._is_tool_call_list(answer)
|
||||
):
|
||||
# Handle tool calls - execute tools and add results to messages
|
||||
tool_finish = self._handle_native_tool_calls(
|
||||
answer, available_functions
|
||||
)
|
||||
# If tool has result_as_answer=True, return immediately
|
||||
if tool_finish is not None:
|
||||
return tool_finish
|
||||
# Continue loop to let LLM analyze results and decide next steps
|
||||
continue
|
||||
|
||||
# Text or other response - handle as potential final answer
|
||||
if isinstance(answer, str):
|
||||
# Text response - this is the final answer
|
||||
formatted_answer = AgentFinish(
|
||||
thought="",
|
||||
output=answer,
|
||||
text=answer,
|
||||
)
|
||||
await self._ainvoke_step_callback(formatted_answer)
|
||||
self._append_message(answer) # Save final answer to messages
|
||||
self._append_message(answer)
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
@@ -1387,14 +1334,13 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
# Unexpected response type, treat as final answer
|
||||
formatted_answer = AgentFinish(
|
||||
thought="",
|
||||
output=str(answer),
|
||||
text=str(answer),
|
||||
)
|
||||
await self._ainvoke_step_callback(formatted_answer)
|
||||
self._append_message(str(answer)) # Save final answer to messages
|
||||
self._append_message(str(answer))
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
@@ -1408,7 +1354,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
messages=self.messages,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
continue
|
||||
@@ -1466,8 +1411,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
Returns:
|
||||
Updated action or final answer.
|
||||
"""
|
||||
# Special case for add_image_tool
|
||||
add_image_tool = self._i18n.tools("add_image")
|
||||
add_image_tool = I18N_DEFAULT.tools("add_image")
|
||||
if (
|
||||
isinstance(add_image_tool, dict)
|
||||
and formatted_answer.tool.casefold().strip()
|
||||
@@ -1586,17 +1530,14 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
training_handler = CrewTrainingHandler(TRAINING_DATA_FILE)
|
||||
training_data = training_handler.load() or {}
|
||||
|
||||
# Initialize or retrieve agent's training data
|
||||
agent_training_data = training_data.get(agent_id, {})
|
||||
|
||||
if human_feedback is not None:
|
||||
# Save initial output and human feedback
|
||||
agent_training_data[train_iteration] = {
|
||||
"initial_output": result.output,
|
||||
"human_feedback": human_feedback,
|
||||
}
|
||||
else:
|
||||
# Save improved output
|
||||
if train_iteration in agent_training_data:
|
||||
agent_training_data[train_iteration]["improved_output"] = result.output
|
||||
else:
|
||||
@@ -1610,7 +1551,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
)
|
||||
return
|
||||
|
||||
# Update the training data and save
|
||||
training_data[agent_id] = agent_training_data
|
||||
training_handler.save(training_data)
|
||||
|
||||
@@ -1673,5 +1613,5 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
Formatted message dict.
|
||||
"""
|
||||
return format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
I18N_DEFAULT.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
|
||||
@@ -19,10 +19,7 @@ from crewai.agents.constants import (
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
UNABLE_TO_REPAIR_JSON_RESULTS,
|
||||
)
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
|
||||
|
||||
_I18N = get_i18n()
|
||||
from crewai.utilities.i18n import I18N_DEFAULT as _I18N
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -97,11 +94,8 @@ def parse(text: str) -> AgentAction | AgentFinish:
|
||||
|
||||
if includes_answer:
|
||||
final_answer = text.split(FINAL_ANSWER_ACTION)[-1].strip()
|
||||
# Check whether the final answer ends with triple backticks.
|
||||
if final_answer.endswith("```"):
|
||||
# Count occurrences of triple backticks in the final answer.
|
||||
count = final_answer.count("```")
|
||||
# If count is odd then it's an unmatched trailing set; remove it.
|
||||
if count % 2 != 0:
|
||||
final_answer = final_answer[:-3].rstrip()
|
||||
return AgentFinish(thought=thought, output=final_answer, text=text)
|
||||
@@ -149,7 +143,6 @@ def _extract_thought(text: str) -> str:
|
||||
if thought_index == -1:
|
||||
return ""
|
||||
thought = text[:thought_index].strip()
|
||||
# Remove any triple backticks from the thought string
|
||||
return thought.replace("```", "").strip()
|
||||
|
||||
|
||||
@@ -174,18 +167,9 @@ def _safe_repair_json(tool_input: str) -> str:
|
||||
Returns:
|
||||
The repaired JSON string or original if repair fails.
|
||||
"""
|
||||
# Skip repair if the input starts and ends with square brackets
|
||||
# Explanation: The JSON parser has issues handling inputs that are enclosed in square brackets ('[]').
|
||||
# These are typically valid JSON arrays or strings that do not require repair. Attempting to repair such inputs
|
||||
# might lead to unintended alterations, such as wrapping the entire input in additional layers or modifying
|
||||
# the structure in a way that changes its meaning. By skipping the repair for inputs that start and end with
|
||||
# square brackets, we preserve the integrity of these valid JSON structures and avoid unnecessary modifications.
|
||||
if tool_input.startswith("[") and tool_input.endswith("]"):
|
||||
return tool_input
|
||||
|
||||
# Before repair, handle common LLM issues:
|
||||
# 1. Replace """ with " to avoid JSON parser errors
|
||||
|
||||
tool_input = tool_input.replace('"""', '"')
|
||||
|
||||
result = repair_json(tool_input)
|
||||
|
||||
@@ -23,7 +23,7 @@ from crewai.events.types.observation_events import (
|
||||
StepObservationStartedEvent,
|
||||
)
|
||||
from crewai.utilities.agent_utils import extract_task_section
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.planning_types import StepObservation, TodoItem
|
||||
from crewai.utilities.types import LLMMessage
|
||||
@@ -64,7 +64,6 @@ class PlannerObserver:
|
||||
self.task = task
|
||||
self.kickoff_input = kickoff_input
|
||||
self.llm = self._resolve_llm()
|
||||
self._i18n: I18N = get_i18n()
|
||||
|
||||
def _resolve_llm(self) -> Any:
|
||||
"""Resolve which LLM to use for observation/planning.
|
||||
@@ -84,10 +83,6 @@ class PlannerObserver:
|
||||
return create_llm(config.llm)
|
||||
return self.agent.llm
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def observe(
|
||||
self,
|
||||
completed_step: TodoItem,
|
||||
@@ -183,9 +178,6 @@ class PlannerObserver:
|
||||
),
|
||||
)
|
||||
|
||||
# Don't force a full replan — the step may have succeeded even if the
|
||||
# observer LLM failed to parse the result. Defaulting to "continue" is
|
||||
# far less disruptive than wiping the entire plan on every observer error.
|
||||
return StepObservation(
|
||||
step_completed_successfully=True,
|
||||
key_information_learned="",
|
||||
@@ -222,10 +214,6 @@ class PlannerObserver:
|
||||
|
||||
return remaining_todos
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal: Message building
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_observation_messages(
|
||||
self,
|
||||
completed_step: TodoItem,
|
||||
@@ -240,15 +228,11 @@ class PlannerObserver:
|
||||
task_desc = self.task.description or ""
|
||||
task_goal = self.task.expected_output or ""
|
||||
elif self.kickoff_input:
|
||||
# Standalone kickoff path — no Task object, but we have the raw input.
|
||||
# Extract just the ## Task section so the observer sees the actual goal,
|
||||
# not the full enriched instruction with env/tools/verification noise.
|
||||
task_desc = extract_task_section(self.kickoff_input)
|
||||
task_goal = "Complete the task successfully"
|
||||
|
||||
system_prompt = self._i18n.retrieve("planning", "observation_system_prompt")
|
||||
system_prompt = I18N_DEFAULT.retrieve("planning", "observation_system_prompt")
|
||||
|
||||
# Build context of what's been done
|
||||
completed_summary = ""
|
||||
if all_completed:
|
||||
completed_lines = []
|
||||
@@ -262,7 +246,6 @@ class PlannerObserver:
|
||||
completed_lines
|
||||
)
|
||||
|
||||
# Build remaining plan
|
||||
remaining_summary = ""
|
||||
if remaining_todos:
|
||||
remaining_lines = [
|
||||
@@ -273,7 +256,9 @@ class PlannerObserver:
|
||||
remaining_lines
|
||||
)
|
||||
|
||||
user_prompt = self._i18n.retrieve("planning", "observation_user_prompt").format(
|
||||
user_prompt = I18N_DEFAULT.retrieve(
|
||||
"planning", "observation_user_prompt"
|
||||
).format(
|
||||
task_description=task_desc,
|
||||
task_goal=task_goal,
|
||||
completed_summary=completed_summary,
|
||||
@@ -305,17 +290,14 @@ class PlannerObserver:
|
||||
if isinstance(response, StepObservation):
|
||||
return response
|
||||
|
||||
# JSON string path — most common miss before this fix
|
||||
if isinstance(response, str):
|
||||
text = response.strip()
|
||||
try:
|
||||
return StepObservation.model_validate_json(text)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
# Some LLMs wrap the JSON in markdown fences
|
||||
if text.startswith("```"):
|
||||
lines = text.split("\n")
|
||||
# Strip first and last lines (``` markers)
|
||||
inner = "\n".join(
|
||||
lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
|
||||
)
|
||||
@@ -324,14 +306,12 @@ class PlannerObserver:
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
# Dict path
|
||||
if isinstance(response, dict):
|
||||
try:
|
||||
return StepObservation.model_validate(response)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
# Last resort — log what we got so it's diagnosable
|
||||
logger.warning(
|
||||
"Could not parse observation response (type=%s). "
|
||||
"Falling back to default failure observation. Preview: %.200s",
|
||||
|
||||
@@ -38,7 +38,7 @@ from crewai.utilities.agent_utils import (
|
||||
process_llm_response,
|
||||
setup_native_tools,
|
||||
)
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.planning_types import TodoItem
|
||||
from crewai.utilities.printer import PRINTER
|
||||
from crewai.utilities.step_execution_context import StepExecutionContext, StepResult
|
||||
@@ -81,7 +81,7 @@ class StepExecutor:
|
||||
function_calling_llm: Optional separate LLM for function calling.
|
||||
request_within_rpm_limit: Optional RPM limit function.
|
||||
callbacks: Optional list of callbacks.
|
||||
i18n: Optional i18n instance.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -96,7 +96,6 @@ class StepExecutor:
|
||||
function_calling_llm: BaseLLM | None = None,
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
i18n: I18N | None = None,
|
||||
) -> None:
|
||||
self.llm = llm
|
||||
self.tools = tools
|
||||
@@ -108,9 +107,7 @@ class StepExecutor:
|
||||
self.function_calling_llm = function_calling_llm
|
||||
self.request_within_rpm_limit = request_within_rpm_limit
|
||||
self.callbacks = callbacks or []
|
||||
self._i18n: I18N = i18n or get_i18n()
|
||||
|
||||
# Native tool support — set up once
|
||||
self._use_native_tools = check_native_tool_support(
|
||||
self.llm, self.original_tools
|
||||
)
|
||||
@@ -123,10 +120,6 @@ class StepExecutor:
|
||||
_,
|
||||
) = setup_native_tools(self.original_tools)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(
|
||||
self,
|
||||
todo: TodoItem,
|
||||
@@ -192,10 +185,6 @@ class StepExecutor:
|
||||
execution_time=elapsed,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal: Message building
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_isolated_messages(
|
||||
self, todo: TodoItem, context: StepExecutionContext
|
||||
) -> list[LLMMessage]:
|
||||
@@ -221,14 +210,14 @@ class StepExecutor:
|
||||
tools_section = ""
|
||||
if self.tools and not self._use_native_tools:
|
||||
tool_names = ", ".join(sanitize_tool_name(t.name) for t in self.tools)
|
||||
tools_section = self._i18n.retrieve(
|
||||
tools_section = I18N_DEFAULT.retrieve(
|
||||
"planning", "step_executor_tools_section"
|
||||
).format(tool_names=tool_names)
|
||||
elif self.tools:
|
||||
tool_names = ", ".join(sanitize_tool_name(t.name) for t in self.tools)
|
||||
tools_section = f"\n\nAvailable tools: {tool_names}"
|
||||
|
||||
return self._i18n.retrieve("planning", "step_executor_system_prompt").format(
|
||||
return I18N_DEFAULT.retrieve("planning", "step_executor_system_prompt").format(
|
||||
role=role,
|
||||
backstory=backstory,
|
||||
goal=goal,
|
||||
@@ -239,15 +228,11 @@ class StepExecutor:
|
||||
"""Build the user prompt for this specific step."""
|
||||
parts: list[str] = []
|
||||
|
||||
# Include overall task context so the executor knows the full goal and
|
||||
# required output format/location — critical for knowing WHAT to produce.
|
||||
# We extract only the task body (not tool instructions or verification
|
||||
# sections) to avoid duplicating directives already in the system prompt.
|
||||
if context.task_description:
|
||||
task_section = extract_task_section(context.task_description)
|
||||
if task_section:
|
||||
parts.append(
|
||||
self._i18n.retrieve(
|
||||
I18N_DEFAULT.retrieve(
|
||||
"planning", "step_executor_task_context"
|
||||
).format(
|
||||
task_context=task_section,
|
||||
@@ -255,38 +240,35 @@ class StepExecutor:
|
||||
)
|
||||
|
||||
parts.append(
|
||||
self._i18n.retrieve("planning", "step_executor_user_prompt").format(
|
||||
I18N_DEFAULT.retrieve("planning", "step_executor_user_prompt").format(
|
||||
step_description=todo.description,
|
||||
)
|
||||
)
|
||||
|
||||
if todo.tool_to_use:
|
||||
parts.append(
|
||||
self._i18n.retrieve("planning", "step_executor_suggested_tool").format(
|
||||
I18N_DEFAULT.retrieve(
|
||||
"planning", "step_executor_suggested_tool"
|
||||
).format(
|
||||
tool_to_use=todo.tool_to_use,
|
||||
)
|
||||
)
|
||||
|
||||
# Include dependency results (final results only, no traces)
|
||||
if context.dependency_results:
|
||||
parts.append(
|
||||
self._i18n.retrieve("planning", "step_executor_context_header")
|
||||
I18N_DEFAULT.retrieve("planning", "step_executor_context_header")
|
||||
)
|
||||
for step_num, result in sorted(context.dependency_results.items()):
|
||||
parts.append(
|
||||
self._i18n.retrieve(
|
||||
I18N_DEFAULT.retrieve(
|
||||
"planning", "step_executor_context_entry"
|
||||
).format(step_number=step_num, result=result)
|
||||
)
|
||||
|
||||
parts.append(self._i18n.retrieve("planning", "step_executor_complete_step"))
|
||||
parts.append(I18N_DEFAULT.retrieve("planning", "step_executor_complete_step"))
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal: Multi-turn execution loop
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _execute_text_parsed(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
@@ -306,7 +288,6 @@ class StepExecutor:
|
||||
last_tool_result = ""
|
||||
|
||||
for _ in range(max_step_iterations):
|
||||
# Check step timeout
|
||||
if step_timeout and start_time:
|
||||
elapsed = time.monotonic() - start_time
|
||||
if elapsed >= step_timeout:
|
||||
@@ -331,17 +312,12 @@ class StepExecutor:
|
||||
tool_calls_made.append(formatted.tool)
|
||||
tool_result = self._execute_text_tool_with_events(formatted)
|
||||
last_tool_result = tool_result
|
||||
# Append the assistant's reasoning + action, then the observation.
|
||||
# _build_observation_message handles vision sentinels so the LLM
|
||||
# receives an image content block instead of raw base64 text.
|
||||
messages.append({"role": "assistant", "content": answer_str})
|
||||
messages.append(self._build_observation_message(tool_result))
|
||||
continue
|
||||
|
||||
# Raw text response with no Final Answer marker — treat as done
|
||||
return answer_str
|
||||
|
||||
# Max iterations reached — return the last tool result we accumulated
|
||||
return last_tool_result
|
||||
|
||||
def _execute_text_tool_with_events(self, formatted: AgentAction) -> str:
|
||||
@@ -375,7 +351,6 @@ class StepExecutor:
|
||||
agent_action=formatted,
|
||||
fingerprint_context=fingerprint_context,
|
||||
tools=self.tools,
|
||||
i18n=self._i18n,
|
||||
agent_key=self.agent.key if self.agent else None,
|
||||
agent_role=self.agent.role if self.agent else None,
|
||||
tools_handler=self.tools_handler,
|
||||
@@ -430,10 +405,6 @@ class StepExecutor:
|
||||
return {"input": stripped_input}
|
||||
return {"input": str(tool_input)}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal: Vision support
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _parse_vision_sentinel(raw: str) -> tuple[str, str] | None:
|
||||
"""Parse a VISION_IMAGE sentinel into (media_type, base64_data), or None."""
|
||||
@@ -518,7 +489,6 @@ class StepExecutor:
|
||||
accumulated_results: list[str] = []
|
||||
|
||||
for _ in range(max_step_iterations):
|
||||
# Check step timeout
|
||||
if step_timeout and start_time:
|
||||
elapsed = time.monotonic() - start_time
|
||||
if elapsed >= step_timeout:
|
||||
@@ -542,19 +512,14 @@ class StepExecutor:
|
||||
return answer.model_dump_json()
|
||||
|
||||
if isinstance(answer, list) and answer and is_tool_call_list(answer):
|
||||
# _execute_native_tool_calls appends assistant + tool messages
|
||||
# to `messages` as a side-effect, so the next LLM call will
|
||||
# see the full conversation history including tool outputs.
|
||||
result = self._execute_native_tool_calls(
|
||||
answer, messages, tool_calls_made
|
||||
)
|
||||
accumulated_results.append(result)
|
||||
continue
|
||||
|
||||
# Text answer → LLM decided the step is done
|
||||
return str(answer)
|
||||
|
||||
# Max iterations reached — return everything we accumulated
|
||||
return "\n".join(filter(None, accumulated_results))
|
||||
|
||||
def _execute_native_tool_calls(
|
||||
@@ -600,9 +565,6 @@ class StepExecutor:
|
||||
parsed = self._parse_vision_sentinel(raw_content)
|
||||
if parsed:
|
||||
media_type, b64_data = parsed
|
||||
# Replace the sentinel with a standard image_url content block.
|
||||
# Each provider's _format_messages handles conversion to
|
||||
# its native format (e.g. Anthropic image blocks).
|
||||
modified: LLMMessage = cast(
|
||||
LLMMessage, dict(call_result.tool_message)
|
||||
)
|
||||
|
||||
@@ -2,16 +2,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
|
||||
_PLACEHOLDER_RE = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}")
|
||||
|
||||
|
||||
_SQLITE_MAGIC = b"SQLite format 3\x00"
|
||||
|
||||
_SELECT_ALL = """
|
||||
@@ -33,6 +37,45 @@ ORDER BY rowid DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
_DELETE_OLDER_THAN = """
|
||||
DELETE FROM checkpoints
|
||||
WHERE created_at < ?
|
||||
"""
|
||||
|
||||
_DELETE_KEEP_N = """
|
||||
DELETE FROM checkpoints WHERE rowid NOT IN (
|
||||
SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ?
|
||||
)
|
||||
"""
|
||||
|
||||
_COUNT_CHECKPOINTS = "SELECT COUNT(*) FROM checkpoints"
|
||||
|
||||
_SELECT_LIKE = """
|
||||
SELECT id, created_at, json(data)
|
||||
FROM checkpoints
|
||||
WHERE id LIKE ?
|
||||
ORDER BY rowid DESC
|
||||
"""
|
||||
|
||||
|
||||
_DEFAULT_DIR = "./.checkpoints"
|
||||
_DEFAULT_DB = "./.checkpoints.db"
|
||||
|
||||
|
||||
def _detect_location(location: str) -> str:
|
||||
"""Resolve the default checkpoint location.
|
||||
|
||||
When the caller passes the default directory path, check whether a
|
||||
SQLite database exists at the conventional ``.db`` path and prefer it.
|
||||
"""
|
||||
if (
|
||||
location == _DEFAULT_DIR
|
||||
and not os.path.exists(_DEFAULT_DIR)
|
||||
and os.path.exists(_DEFAULT_DB)
|
||||
):
|
||||
return _DEFAULT_DB
|
||||
return location
|
||||
|
||||
|
||||
def _is_sqlite(path: str) -> bool:
|
||||
"""Check if a file is a SQLite database by reading its magic bytes."""
|
||||
@@ -52,13 +95,7 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
|
||||
nodes = data.get("event_record", {}).get("nodes", {})
|
||||
event_count = len(nodes)
|
||||
|
||||
trigger_event = None
|
||||
if nodes:
|
||||
last_node = max(
|
||||
nodes.values(),
|
||||
key=lambda n: n.get("event", {}).get("emission_sequence") or 0,
|
||||
)
|
||||
trigger_event = last_node.get("event", {}).get("type")
|
||||
trigger_event = data.get("trigger")
|
||||
|
||||
parsed_entities: list[dict[str, Any]] = []
|
||||
for entity in entities:
|
||||
@@ -76,16 +113,47 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
|
||||
{
|
||||
"description": t.get("description", ""),
|
||||
"completed": t.get("output") is not None,
|
||||
"output": (t.get("output") or {}).get("raw", ""),
|
||||
}
|
||||
for t in tasks
|
||||
]
|
||||
parsed_entities.append(info)
|
||||
|
||||
inputs: dict[str, Any] = {}
|
||||
for entity in entities:
|
||||
cp_inputs = entity.get("checkpoint_inputs")
|
||||
if isinstance(cp_inputs, dict) and cp_inputs:
|
||||
inputs = dict(cp_inputs)
|
||||
break
|
||||
|
||||
for entity in entities:
|
||||
for task in entity.get("tasks", []):
|
||||
for field in (
|
||||
"checkpoint_original_description",
|
||||
"checkpoint_original_expected_output",
|
||||
):
|
||||
text = task.get(field) or ""
|
||||
for match in _PLACEHOLDER_RE.findall(text):
|
||||
if match not in inputs:
|
||||
inputs[match] = ""
|
||||
for agent in entity.get("agents", []):
|
||||
for field in ("role", "goal", "backstory"):
|
||||
text = agent.get(field) or ""
|
||||
for match in _PLACEHOLDER_RE.findall(text):
|
||||
if match not in inputs:
|
||||
inputs[match] = ""
|
||||
|
||||
branch = data.get("branch", "main")
|
||||
parent_id = data.get("parent_id")
|
||||
|
||||
return {
|
||||
"source": source,
|
||||
"event_count": event_count,
|
||||
"trigger": trigger_event,
|
||||
"entities": parsed_entities,
|
||||
"branch": branch,
|
||||
"parent_id": parent_id,
|
||||
"inputs": inputs,
|
||||
}
|
||||
|
||||
|
||||
@@ -125,9 +193,11 @@ def _entity_summary(entities: list[dict[str, Any]]) -> str:
|
||||
|
||||
|
||||
def _list_json(location: str) -> list[dict[str, Any]]:
|
||||
pattern = os.path.join(location, "*.json")
|
||||
pattern = os.path.join(location, "**", "*.json")
|
||||
results = []
|
||||
for path in sorted(glob.glob(pattern), key=os.path.getmtime, reverse=True):
|
||||
for path in sorted(
|
||||
glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
|
||||
):
|
||||
name = os.path.basename(path)
|
||||
try:
|
||||
with open(path) as f:
|
||||
@@ -144,8 +214,10 @@ def _list_json(location: str) -> list[dict[str, Any]]:
|
||||
|
||||
|
||||
def _info_json_latest(location: str) -> dict[str, Any] | None:
|
||||
pattern = os.path.join(location, "*.json")
|
||||
files = sorted(glob.glob(pattern), key=os.path.getmtime, reverse=True)
|
||||
pattern = os.path.join(location, "**", "*.json")
|
||||
files = sorted(
|
||||
glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
|
||||
)
|
||||
if not files:
|
||||
return None
|
||||
path = files[0]
|
||||
@@ -189,6 +261,7 @@ def _list_sqlite(db_path: str) -> list[dict[str, Any]]:
|
||||
"entities": [],
|
||||
"source": checkpoint_id,
|
||||
}
|
||||
meta["db"] = db_path
|
||||
results.append(meta)
|
||||
return results
|
||||
|
||||
@@ -209,6 +282,8 @@ def _info_sqlite_latest(db_path: str) -> dict[str, Any] | None:
|
||||
def _info_sqlite_id(db_path: str, checkpoint_id: str) -> dict[str, Any] | None:
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
row = conn.execute(_SELECT_ONE, (checkpoint_id,)).fetchone()
|
||||
if not row:
|
||||
row = conn.execute(_SELECT_LIKE, (f"%{checkpoint_id}%",)).fetchone()
|
||||
if not row:
|
||||
return None
|
||||
cid, created_at, raw = row
|
||||
@@ -311,6 +386,10 @@ def _print_info(meta: dict[str, Any]) -> None:
|
||||
trigger = meta.get("trigger")
|
||||
if trigger:
|
||||
click.echo(f"Trigger: {trigger}")
|
||||
click.echo(f"Branch: {meta.get('branch', 'main')}")
|
||||
parent_id = meta.get("parent_id")
|
||||
if parent_id:
|
||||
click.echo(f"Parent: {parent_id}")
|
||||
|
||||
for ent in meta.get("entities", []):
|
||||
eid = str(ent.get("id", ""))[:8]
|
||||
@@ -327,3 +406,287 @@ def _print_info(meta: dict[str, Any]) -> None:
|
||||
if len(desc) > 70:
|
||||
desc = desc[:67] + "..."
|
||||
click.echo(f" {i + 1}. [{status}] {desc}")
|
||||
|
||||
|
||||
def _resolve_checkpoint(
|
||||
location: str, checkpoint_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
if _is_sqlite(location):
|
||||
if checkpoint_id:
|
||||
return _info_sqlite_id(location, checkpoint_id)
|
||||
return _info_sqlite_latest(location)
|
||||
if os.path.isdir(location):
|
||||
if checkpoint_id:
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
|
||||
_json_provider: JsonProvider = JsonProvider()
|
||||
pattern: str = os.path.join(location, "**", "*.json")
|
||||
all_files: list[str] = glob.glob(pattern, recursive=True)
|
||||
matches: list[str] = [
|
||||
f for f in all_files if checkpoint_id in _json_provider.extract_id(f)
|
||||
]
|
||||
matches.sort(key=os.path.getmtime, reverse=True)
|
||||
if matches:
|
||||
return _info_json_file(matches[0])
|
||||
return None
|
||||
return _info_json_latest(location)
|
||||
if os.path.isfile(location):
|
||||
return _info_json_file(location)
|
||||
return None
|
||||
|
||||
|
||||
def _entity_type_from_meta(meta: dict[str, Any]) -> str:
|
||||
for ent in meta.get("entities", []):
|
||||
if ent.get("type") == "flow":
|
||||
return "flow"
|
||||
return "crew"
|
||||
|
||||
|
||||
def resume_checkpoint(location: str, checkpoint_id: str | None) -> None:
|
||||
import asyncio
|
||||
|
||||
meta: dict[str, Any] | None = _resolve_checkpoint(location, checkpoint_id)
|
||||
if meta is None:
|
||||
if checkpoint_id:
|
||||
click.echo(f"Checkpoint not found: {checkpoint_id}")
|
||||
else:
|
||||
click.echo(f"No checkpoints found in {location}")
|
||||
return
|
||||
|
||||
restore_path: str = meta.get("path") or meta.get("source", "")
|
||||
if meta.get("db"):
|
||||
restore_path = f"{meta['db']}#{meta['name']}"
|
||||
|
||||
click.echo(f"Resuming from: {meta.get('name', restore_path)}")
|
||||
_print_info(meta)
|
||||
click.echo()
|
||||
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
|
||||
config: CheckpointConfig = CheckpointConfig(restore_from=restore_path)
|
||||
entity_type: str = _entity_type_from_meta(meta)
|
||||
inputs: dict[str, Any] | None = meta.get("inputs") or None
|
||||
|
||||
if entity_type == "flow":
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
flow = Flow.from_checkpoint(config)
|
||||
result = asyncio.run(flow.kickoff_async(inputs=inputs))
|
||||
else:
|
||||
from crewai.crew import Crew
|
||||
|
||||
crew = Crew.from_checkpoint(config)
|
||||
result = asyncio.run(crew.akickoff(inputs=inputs))
|
||||
|
||||
click.echo(f"\nResult: {getattr(result, 'raw', result)}")
|
||||
|
||||
|
||||
def _task_list_from_meta(meta: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
tasks: list[dict[str, Any]] = []
|
||||
for ent in meta.get("entities", []):
|
||||
tasks.extend(
|
||||
{
|
||||
"entity": ent.get("name", "unnamed"),
|
||||
"description": t.get("description", ""),
|
||||
"completed": t.get("completed", False),
|
||||
"output": t.get("output", ""),
|
||||
}
|
||||
for t in ent.get("tasks", [])
|
||||
)
|
||||
return tasks
|
||||
|
||||
|
||||
def diff_checkpoints(location: str, id1: str, id2: str) -> None:
|
||||
meta1: dict[str, Any] | None = _resolve_checkpoint(location, id1)
|
||||
meta2: dict[str, Any] | None = _resolve_checkpoint(location, id2)
|
||||
|
||||
if meta1 is None:
|
||||
click.echo(f"Checkpoint not found: {id1}")
|
||||
return
|
||||
if meta2 is None:
|
||||
click.echo(f"Checkpoint not found: {id2}")
|
||||
return
|
||||
|
||||
name1: str = meta1.get("name", id1)
|
||||
name2: str = meta2.get("name", id2)
|
||||
|
||||
click.echo(f"--- {name1}")
|
||||
click.echo(f"+++ {name2}")
|
||||
click.echo()
|
||||
|
||||
fields: list[tuple[str, str]] = [
|
||||
("Time", "ts"),
|
||||
("Branch", "branch"),
|
||||
("Trigger", "trigger"),
|
||||
("Events", "event_count"),
|
||||
]
|
||||
for label, key in fields:
|
||||
v1: str = str(meta1.get(key, ""))
|
||||
v2: str = str(meta2.get(key, ""))
|
||||
if v1 != v2:
|
||||
click.echo(f" {label}:")
|
||||
click.echo(f" - {v1}")
|
||||
click.echo(f" + {v2}")
|
||||
|
||||
inputs1: dict[str, Any] = meta1.get("inputs", {})
|
||||
inputs2: dict[str, Any] = meta2.get("inputs", {})
|
||||
all_keys: list[str] = sorted(set(list(inputs1.keys()) + list(inputs2.keys())))
|
||||
changed_inputs: list[tuple[str, Any, Any]] = [
|
||||
(k, inputs1.get(k, ""), inputs2.get(k, ""))
|
||||
for k in all_keys
|
||||
if inputs1.get(k) != inputs2.get(k)
|
||||
]
|
||||
if changed_inputs:
|
||||
click.echo("\n Inputs:")
|
||||
for key, v1, v2 in changed_inputs:
|
||||
click.echo(f" {key}:")
|
||||
click.echo(f" - {v1}")
|
||||
click.echo(f" + {v2}")
|
||||
|
||||
tasks1: list[dict[str, Any]] = _task_list_from_meta(meta1)
|
||||
tasks2: list[dict[str, Any]] = _task_list_from_meta(meta2)
|
||||
|
||||
max_tasks: int = max(len(tasks1), len(tasks2))
|
||||
if max_tasks == 0:
|
||||
return
|
||||
|
||||
click.echo("\n Tasks:")
|
||||
for i in range(max_tasks):
|
||||
t1: dict[str, Any] | None = tasks1[i] if i < len(tasks1) else None
|
||||
t2: dict[str, Any] | None = tasks2[i] if i < len(tasks2) else None
|
||||
|
||||
if t1 is None:
|
||||
desc: str = t2["description"][:60] if t2 else ""
|
||||
click.echo(f" + {i + 1}. [new] {desc}")
|
||||
continue
|
||||
if t2 is None:
|
||||
desc = t1["description"][:60]
|
||||
click.echo(f" - {i + 1}. [removed] {desc}")
|
||||
continue
|
||||
|
||||
desc = str(t1["description"][:60])
|
||||
s1: str = "done" if t1["completed"] else "pending"
|
||||
s2: str = "done" if t2["completed"] else "pending"
|
||||
|
||||
if s1 != s2:
|
||||
click.echo(f" {i + 1}. {desc}")
|
||||
click.echo(f" status: {s1} -> {s2}")
|
||||
|
||||
out1: str = (t1.get("output") or "").strip()
|
||||
out2: str = (t2.get("output") or "").strip()
|
||||
if out1 != out2:
|
||||
if s1 == s2:
|
||||
click.echo(f" {i + 1}. {desc}")
|
||||
preview1: str = (
|
||||
out1[:80] + ("..." if len(out1) > 80 else "") if out1 else "(empty)"
|
||||
)
|
||||
preview2: str = (
|
||||
out2[:80] + ("..." if len(out2) > 80 else "") if out2 else "(empty)"
|
||||
)
|
||||
click.echo(" output:")
|
||||
click.echo(f" - {preview1}")
|
||||
click.echo(f" + {preview2}")
|
||||
|
||||
|
||||
def _parse_duration(value: str) -> timedelta:
|
||||
match: re.Match[str] | None = re.match(r"^(\d+)([dhm])$", value.strip())
|
||||
if not match:
|
||||
raise click.BadParameter(
|
||||
f"Invalid duration: {value!r}. Use format like '7d', '24h', or '30m'."
|
||||
)
|
||||
amount: int = int(match.group(1))
|
||||
unit: str = match.group(2)
|
||||
if unit == "d":
|
||||
return timedelta(days=amount)
|
||||
if unit == "h":
|
||||
return timedelta(hours=amount)
|
||||
return timedelta(minutes=amount)
|
||||
|
||||
|
||||
def _prune_json(location: str, keep: int | None, older_than: timedelta | None) -> int:
|
||||
pattern: str = os.path.join(location, "**", "*.json")
|
||||
files: list[str] = sorted(
|
||||
glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
|
||||
)
|
||||
if not files:
|
||||
return 0
|
||||
|
||||
to_delete: set[str] = set()
|
||||
|
||||
if keep is not None and len(files) > keep:
|
||||
to_delete.update(files[keep:])
|
||||
|
||||
if older_than is not None:
|
||||
cutoff: datetime = datetime.now(timezone.utc) - older_than
|
||||
for path in files:
|
||||
mtime: datetime = datetime.fromtimestamp(
|
||||
os.path.getmtime(path), tz=timezone.utc
|
||||
)
|
||||
if mtime < cutoff:
|
||||
to_delete.add(path)
|
||||
|
||||
deleted: int = 0
|
||||
for path in to_delete:
|
||||
try:
|
||||
os.remove(path)
|
||||
deleted += 1
|
||||
except OSError: # noqa: PERF203
|
||||
pass
|
||||
|
||||
for dirpath, dirnames, filenames in os.walk(location, topdown=False):
|
||||
if dirpath != location and not filenames and not dirnames:
|
||||
try:
|
||||
os.rmdir(dirpath)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
def _prune_sqlite(db_path: str, keep: int | None, older_than: timedelta | None) -> int:
|
||||
deleted: int = 0
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
if older_than is not None:
|
||||
cutoff: str = (datetime.now(timezone.utc) - older_than).strftime(
|
||||
"%Y%m%dT%H%M%S"
|
||||
)
|
||||
cursor: sqlite3.Cursor = conn.execute(_DELETE_OLDER_THAN, (cutoff,))
|
||||
deleted += cursor.rowcount
|
||||
|
||||
if keep is not None:
|
||||
cursor = conn.execute(_DELETE_KEEP_N, (keep,))
|
||||
deleted += cursor.rowcount
|
||||
|
||||
conn.commit()
|
||||
return deleted
|
||||
|
||||
|
||||
def prune_checkpoints(
|
||||
location: str, keep: int | None, older_than: str | None, dry_run: bool = False
|
||||
) -> None:
|
||||
if keep is None and older_than is None:
|
||||
click.echo("Specify --keep N and/or --older-than DURATION (e.g. 7d, 24h)")
|
||||
return
|
||||
|
||||
duration: timedelta | None = _parse_duration(older_than) if older_than else None
|
||||
|
||||
deleted: int
|
||||
if _is_sqlite(location):
|
||||
if dry_run:
|
||||
with sqlite3.connect(location) as conn:
|
||||
total: int = conn.execute(_COUNT_CHECKPOINTS).fetchone()[0]
|
||||
click.echo(f"Would prune from {total} checkpoint(s) in {location}")
|
||||
return
|
||||
deleted = _prune_sqlite(location, keep, duration)
|
||||
elif os.path.isdir(location):
|
||||
if dry_run:
|
||||
files: list[str] = glob.glob(
|
||||
os.path.join(location, "**", "*.json"), recursive=True
|
||||
)
|
||||
click.echo(f"Would prune from {len(files)} checkpoint(s) in {location}")
|
||||
return
|
||||
deleted = _prune_json(location, keep, duration)
|
||||
else:
|
||||
click.echo(f"Not a directory or SQLite database: {location}")
|
||||
return
|
||||
click.echo(f"Pruned {deleted} checkpoint(s) from {location}")
|
||||
|
||||
686
lib/crewai/src/crewai/cli/checkpoint_tui.py
Normal file
686
lib/crewai/src/crewai/cli/checkpoint_tui.py
Normal file
@@ -0,0 +1,686 @@
|
||||
"""Textual TUI for browsing checkpoint files."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Horizontal, Vertical, VerticalScroll
|
||||
from textual.widgets import (
|
||||
Button,
|
||||
Footer,
|
||||
Header,
|
||||
Input,
|
||||
Static,
|
||||
TextArea,
|
||||
Tree,
|
||||
)
|
||||
|
||||
from crewai.cli.checkpoint_cli import (
|
||||
_format_size,
|
||||
_is_sqlite,
|
||||
_list_json,
|
||||
_list_sqlite,
|
||||
)
|
||||
|
||||
|
||||
_PRIMARY = "#eb6658"
|
||||
_SECONDARY = "#1F7982"
|
||||
_TERTIARY = "#ffffff"
|
||||
_DIM = "#888888"
|
||||
_BG_DARK = "#0d1117"
|
||||
_BG_PANEL = "#161b22"
|
||||
|
||||
|
||||
def _load_entries(location: str) -> list[dict[str, Any]]:
|
||||
if _is_sqlite(location):
|
||||
return _list_sqlite(location)
|
||||
return _list_json(location)
|
||||
|
||||
|
||||
def _short_id(name: str) -> str:
|
||||
"""Shorten a checkpoint name for tree display."""
|
||||
if len(name) > 30:
|
||||
return name[:27] + "..."
|
||||
return name
|
||||
|
||||
|
||||
def _entry_id(entry: dict[str, Any]) -> str:
|
||||
"""Normalize an entry's name into its checkpoint ID.
|
||||
|
||||
JSON filenames are ``{ts}_{uuid}_p-{parent}.json``; SQLite IDs
|
||||
are already ``{ts}_{uuid}``. This strips the JSON suffix so
|
||||
fork-parent lookups work in both providers.
|
||||
"""
|
||||
name = str(entry.get("name", ""))
|
||||
if name.endswith(".json"):
|
||||
name = name[: -len(".json")]
|
||||
idx = name.find("_p-")
|
||||
if idx != -1:
|
||||
name = name[:idx]
|
||||
return name
|
||||
|
||||
|
||||
def _build_entity_header(ent: dict[str, Any]) -> str:
|
||||
"""Build rich text header for an entity (progress bar only)."""
|
||||
lines: list[str] = []
|
||||
tasks = ent.get("tasks")
|
||||
if isinstance(tasks, list):
|
||||
completed = ent.get("tasks_completed", 0)
|
||||
total = ent.get("tasks_total", 0)
|
||||
pct = int(completed / total * 100) if total else 0
|
||||
bar_len = 20
|
||||
filled = int(bar_len * completed / total) if total else 0
|
||||
bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]"
|
||||
lines.append(f"{bar} {completed}/{total} tasks ({pct}%)")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# Return type: (location, action, inputs, task_output_overrides, entity_type)
|
||||
_TuiResult = (
|
||||
tuple[
|
||||
str,
|
||||
str,
|
||||
dict[str, Any] | None,
|
||||
dict[int, str] | None,
|
||||
Literal["crew", "flow"],
|
||||
]
|
||||
| None
|
||||
)
|
||||
|
||||
|
||||
class CheckpointTUI(App[_TuiResult]):
|
||||
"""TUI to browse and inspect checkpoints.
|
||||
|
||||
Returns ``(location, action, inputs, task_overrides, entity_type)``
|
||||
where action is ``"resume"`` or ``"fork"``, inputs is a parsed dict
|
||||
or ``None``, and entity_type is ``"crew"`` or ``"flow"``;
|
||||
or ``None`` if the user quit without selecting.
|
||||
"""
|
||||
|
||||
TITLE = "CrewAI Checkpoints"
|
||||
|
||||
CSS = f"""
|
||||
Screen {{
|
||||
background: {_BG_DARK};
|
||||
}}
|
||||
Header {{
|
||||
background: {_PRIMARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
Footer {{
|
||||
background: {_SECONDARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
Footer > .footer-key--key {{
|
||||
background: {_PRIMARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
#main-layout {{
|
||||
height: 1fr;
|
||||
}}
|
||||
#tree-panel {{
|
||||
width: 45%;
|
||||
background: {_BG_PANEL};
|
||||
border: round {_SECONDARY};
|
||||
padding: 0 1;
|
||||
scrollbar-color: {_PRIMARY};
|
||||
}}
|
||||
#tree-panel:focus-within {{
|
||||
border: round {_PRIMARY};
|
||||
}}
|
||||
#detail-container {{
|
||||
width: 55%;
|
||||
height: 1fr;
|
||||
}}
|
||||
#detail-scroll {{
|
||||
height: 1fr;
|
||||
background: {_BG_PANEL};
|
||||
border: round {_SECONDARY};
|
||||
padding: 1 2;
|
||||
scrollbar-color: {_PRIMARY};
|
||||
}}
|
||||
#detail-scroll:focus-within {{
|
||||
border: round {_PRIMARY};
|
||||
}}
|
||||
#detail-header {{
|
||||
margin-bottom: 1;
|
||||
}}
|
||||
#status {{
|
||||
height: 1;
|
||||
padding: 0 2;
|
||||
color: {_DIM};
|
||||
}}
|
||||
#inputs-section {{
|
||||
display: none;
|
||||
height: auto;
|
||||
max-height: 8;
|
||||
padding: 0 1;
|
||||
}}
|
||||
#inputs-section.visible {{
|
||||
display: block;
|
||||
}}
|
||||
#inputs-label {{
|
||||
height: 1;
|
||||
color: {_DIM};
|
||||
padding: 0 1;
|
||||
}}
|
||||
.input-row {{
|
||||
height: 3;
|
||||
padding: 0 1;
|
||||
}}
|
||||
.input-row Static {{
|
||||
width: auto;
|
||||
min-width: 12;
|
||||
padding: 1 1 0 0;
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
.input-row Input {{
|
||||
width: 1fr;
|
||||
}}
|
||||
#no-inputs-label {{
|
||||
height: 1;
|
||||
color: {_DIM};
|
||||
padding: 0 1;
|
||||
}}
|
||||
#action-buttons {{
|
||||
height: 3;
|
||||
align: right middle;
|
||||
padding: 0 1;
|
||||
display: none;
|
||||
}}
|
||||
#action-buttons.visible {{
|
||||
display: block;
|
||||
}}
|
||||
#action-buttons Button {{
|
||||
margin: 0 0 0 1;
|
||||
min-width: 10;
|
||||
}}
|
||||
#btn-resume {{
|
||||
background: {_SECONDARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
#btn-resume:hover {{
|
||||
background: {_PRIMARY};
|
||||
}}
|
||||
#btn-fork {{
|
||||
background: {_PRIMARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
#btn-fork:hover {{
|
||||
background: {_SECONDARY};
|
||||
}}
|
||||
.entity-title {{
|
||||
padding: 1 1 0 1;
|
||||
}}
|
||||
.entity-detail {{
|
||||
padding: 0 1;
|
||||
}}
|
||||
.task-output-editor {{
|
||||
height: auto;
|
||||
max-height: 10;
|
||||
margin: 0 1 1 1;
|
||||
border: round {_DIM};
|
||||
}}
|
||||
.task-output-editor:focus {{
|
||||
border: round {_PRIMARY};
|
||||
}}
|
||||
.task-label {{
|
||||
padding: 0 1;
|
||||
}}
|
||||
Tree {{
|
||||
background: {_BG_PANEL};
|
||||
}}
|
||||
Tree > .tree--cursor {{
|
||||
background: {_SECONDARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
"""
|
||||
|
||||
BINDINGS: ClassVar[list[Binding | tuple[str, str] | tuple[str, str, str]]] = [
|
||||
("q", "quit", "Quit"),
|
||||
("r", "refresh", "Refresh"),
|
||||
]
|
||||
|
||||
def __init__(self, location: str = "./.checkpoints") -> None:
|
||||
super().__init__()
|
||||
self._location = location
|
||||
self._entries: list[dict[str, Any]] = []
|
||||
self._selected_entry: dict[str, Any] | None = None
|
||||
self._input_keys: list[str] = []
|
||||
self._task_output_ids: list[tuple[int, str, str]] = []
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=False)
|
||||
with Horizontal(id="main-layout"):
|
||||
tree: Tree[dict[str, Any]] = Tree("Checkpoints", id="tree-panel")
|
||||
tree.show_root = True
|
||||
tree.guide_depth = 3
|
||||
yield tree
|
||||
with Vertical(id="detail-container"):
|
||||
yield Static("", id="status")
|
||||
with VerticalScroll(id="detail-scroll"):
|
||||
yield Static(
|
||||
f"[{_DIM}]Select a checkpoint from the tree[/]", # noqa: S608
|
||||
id="detail-header",
|
||||
)
|
||||
with Vertical(id="inputs-section"):
|
||||
yield Static("Inputs", id="inputs-label")
|
||||
with Horizontal(id="action-buttons"):
|
||||
yield Button("Resume", id="btn-resume")
|
||||
yield Button("Fork", id="btn-fork")
|
||||
yield Footer()
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
self._refresh_tree()
|
||||
self.query_one("#tree-panel", Tree).root.expand()
|
||||
|
||||
def _refresh_tree(self) -> None:
|
||||
self._entries = _load_entries(self._location)
|
||||
self._selected_entry = None
|
||||
|
||||
tree = self.query_one("#tree-panel", Tree)
|
||||
tree.clear()
|
||||
|
||||
if not self._entries:
|
||||
self.query_one("#detail-header", Static).update(
|
||||
f"[{_DIM}]No checkpoints in {self._location}[/]"
|
||||
)
|
||||
self.query_one("#status", Static).update("")
|
||||
self.sub_title = self._location
|
||||
return
|
||||
|
||||
# Group by branch
|
||||
branches: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
for entry in self._entries:
|
||||
branch = entry.get("branch", "main")
|
||||
branches[branch].append(entry)
|
||||
|
||||
# Index checkpoint names to tree nodes so forks can attach
|
||||
node_by_name: dict[str, Any] = {}
|
||||
|
||||
def _make_label(e: dict[str, Any]) -> str:
|
||||
name = e.get("name", "")
|
||||
ts = e.get("ts") or ""
|
||||
trigger = e.get("trigger") or ""
|
||||
parts = [f"[bold]{_short_id(name)}[/]"]
|
||||
if ts:
|
||||
time_part = ts.split(" ")[-1] if " " in ts else ts
|
||||
parts.append(f"[{_DIM}]{time_part}[/]")
|
||||
if trigger:
|
||||
parts.append(f"[{_PRIMARY}]{trigger}[/]")
|
||||
return " ".join(parts)
|
||||
|
||||
fork_parents: set[str] = set()
|
||||
for branch_name, entries in branches.items():
|
||||
if branch_name == "main" or not entries:
|
||||
continue
|
||||
oldest = min(entries, key=lambda e: str(e.get("name", "")))
|
||||
first_parent = oldest.get("parent_id")
|
||||
if first_parent:
|
||||
fork_parents.add(str(first_parent))
|
||||
|
||||
def _add_checkpoint(parent_node: Any, e: dict[str, Any]) -> None:
|
||||
"""Add a checkpoint node — expandable only if a fork attaches to it."""
|
||||
cp_id = _entry_id(e)
|
||||
if cp_id in fork_parents:
|
||||
node = parent_node.add(
|
||||
_make_label(e), data=e, expand=False, allow_expand=True
|
||||
)
|
||||
else:
|
||||
node = parent_node.add_leaf(_make_label(e), data=e)
|
||||
node_by_name[cp_id] = node
|
||||
|
||||
if "main" in branches:
|
||||
for entry in reversed(branches["main"]):
|
||||
_add_checkpoint(tree.root, entry)
|
||||
|
||||
fork_branches = [
|
||||
(name, sorted(entries, key=lambda e: str(e.get("name", ""))))
|
||||
for name, entries in branches.items()
|
||||
if name != "main"
|
||||
]
|
||||
remaining = fork_branches
|
||||
max_passes = len(remaining) + 1
|
||||
while remaining and max_passes > 0:
|
||||
max_passes -= 1
|
||||
deferred = []
|
||||
made_progress = False
|
||||
for branch_name, entries in remaining:
|
||||
first_parent = entries[0].get("parent_id") if entries else None
|
||||
if first_parent and str(first_parent) not in node_by_name:
|
||||
deferred.append((branch_name, entries))
|
||||
continue
|
||||
attach_to: Any = tree.root
|
||||
if first_parent:
|
||||
attach_to = node_by_name.get(str(first_parent), tree.root)
|
||||
branch_label = (
|
||||
f"[bold {_SECONDARY}]{branch_name}[/] [{_DIM}]({len(entries)})[/]"
|
||||
)
|
||||
branch_node = attach_to.add(branch_label, expand=False)
|
||||
for entry in entries:
|
||||
_add_checkpoint(branch_node, entry)
|
||||
made_progress = True
|
||||
remaining = deferred
|
||||
if not made_progress:
|
||||
break
|
||||
|
||||
for branch_name, entries in remaining:
|
||||
branch_label = (
|
||||
f"[bold {_SECONDARY}]{branch_name}[/] "
|
||||
f"[{_DIM}]({len(entries)})[/] [{_DIM}](orphaned)[/]"
|
||||
)
|
||||
branch_node = tree.root.add(branch_label, expand=False)
|
||||
for entry in entries:
|
||||
_add_checkpoint(branch_node, entry)
|
||||
|
||||
count = len(self._entries)
|
||||
storage = "SQLite" if _is_sqlite(self._location) else "JSON"
|
||||
self.sub_title = self._location
|
||||
self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}")
|
||||
|
||||
async def _show_detail(self, entry: dict[str, Any]) -> None:
|
||||
"""Update the detail panel for a checkpoint entry."""
|
||||
self._selected_entry = entry
|
||||
self.query_one("#action-buttons").add_class("visible")
|
||||
|
||||
detail_scroll = self.query_one("#detail-scroll", VerticalScroll)
|
||||
|
||||
# Remove all dynamic children except the header — await so IDs are freed
|
||||
to_remove = [c for c in detail_scroll.children if c.id != "detail-header"]
|
||||
for child in to_remove:
|
||||
await child.remove()
|
||||
|
||||
# Header
|
||||
name = entry.get("name", "")
|
||||
ts = entry.get("ts") or "unknown"
|
||||
trigger = entry.get("trigger") or ""
|
||||
branch = entry.get("branch", "main")
|
||||
parent_id = entry.get("parent_id")
|
||||
|
||||
header_lines = [
|
||||
f"[bold {_PRIMARY}]{name}[/]",
|
||||
f"[{_DIM}]{'─' * 50}[/]",
|
||||
"",
|
||||
f" [bold]Time[/] {ts}",
|
||||
]
|
||||
if "size" in entry:
|
||||
header_lines.append(f" [bold]Size[/] {_format_size(entry['size'])}")
|
||||
header_lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}")
|
||||
if trigger:
|
||||
header_lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]")
|
||||
header_lines.append(f" [bold]Branch[/] [{_SECONDARY}]{branch}[/]")
|
||||
if parent_id:
|
||||
header_lines.append(f" [bold]Parent[/] [{_DIM}]{parent_id}[/]")
|
||||
if "path" in entry:
|
||||
header_lines.append(f" [bold]Path[/] [{_DIM}]{entry['path']}[/]")
|
||||
if "db" in entry:
|
||||
header_lines.append(f" [bold]Database[/] [{_DIM}]{entry['db']}[/]")
|
||||
|
||||
self.query_one("#detail-header", Static).update("\n".join(header_lines))
|
||||
|
||||
# Entity details and editable task outputs — mounted flat for scrolling
|
||||
self._task_output_ids = []
|
||||
flat_task_idx = 0
|
||||
for ent_idx, ent in enumerate(entry.get("entities", [])):
|
||||
etype = ent.get("type", "unknown")
|
||||
ename = ent.get("name", "unnamed")
|
||||
completed = ent.get("tasks_completed")
|
||||
total = ent.get("tasks_total")
|
||||
entity_title = f"[bold {_SECONDARY}]{etype}: {ename}[/]"
|
||||
if completed is not None and total is not None:
|
||||
entity_title += f" [{_DIM}]{completed}/{total} tasks[/]"
|
||||
await detail_scroll.mount(Static(entity_title, classes="entity-title"))
|
||||
await detail_scroll.mount(
|
||||
Static(_build_entity_header(ent), classes="entity-detail")
|
||||
)
|
||||
|
||||
tasks = ent.get("tasks", [])
|
||||
for i, task in enumerate(tasks):
|
||||
desc = str(task.get("description", ""))
|
||||
if len(desc) > 55:
|
||||
desc = desc[:52] + "..."
|
||||
if task.get("completed"):
|
||||
icon = "[green]✓[/]"
|
||||
await detail_scroll.mount(
|
||||
Static(f" {icon} {i + 1}. {desc}", classes="task-label")
|
||||
)
|
||||
output_text = task.get("output", "")
|
||||
editor_id = f"task-output-{ent_idx}-{i}"
|
||||
await detail_scroll.mount(
|
||||
TextArea(
|
||||
str(output_text),
|
||||
classes="task-output-editor",
|
||||
id=editor_id,
|
||||
)
|
||||
)
|
||||
self._task_output_ids.append(
|
||||
(flat_task_idx, editor_id, str(output_text))
|
||||
)
|
||||
else:
|
||||
icon = "[yellow]○[/]"
|
||||
await detail_scroll.mount(
|
||||
Static(f" {icon} {i + 1}. {desc}", classes="task-label")
|
||||
)
|
||||
flat_task_idx += 1
|
||||
|
||||
# Build input fields
|
||||
await self._build_input_fields(entry.get("inputs", {}))
|
||||
|
||||
async def _build_input_fields(self, inputs: dict[str, Any]) -> None:
|
||||
"""Rebuild the inputs section with one field per input key."""
|
||||
section = self.query_one("#inputs-section")
|
||||
|
||||
# Remove old dynamic children — await so IDs are freed
|
||||
for widget in list(section.query(".input-row, .no-inputs")):
|
||||
await widget.remove()
|
||||
|
||||
self._input_keys = []
|
||||
|
||||
if not inputs:
|
||||
await section.mount(Static(f"[{_DIM}]No inputs[/]", classes="no-inputs"))
|
||||
section.add_class("visible")
|
||||
return
|
||||
|
||||
for key, value in inputs.items():
|
||||
self._input_keys.append(key)
|
||||
row = Horizontal(classes="input-row")
|
||||
row.compose_add_child(Static(f"[bold]{key}[/]"))
|
||||
row.compose_add_child(
|
||||
Input(value=str(value), placeholder=key, id=f"input-{key}")
|
||||
)
|
||||
await section.mount(row)
|
||||
|
||||
section.add_class("visible")
|
||||
|
||||
def _collect_inputs(self) -> dict[str, Any] | None:
|
||||
"""Collect current values from input fields."""
|
||||
if not self._input_keys:
|
||||
return None
|
||||
result: dict[str, Any] = {}
|
||||
for key in self._input_keys:
|
||||
widget = self.query_one(f"#input-{key}", Input)
|
||||
result[key] = widget.value
|
||||
return result
|
||||
|
||||
def _collect_task_overrides(self) -> dict[int, str] | None:
|
||||
"""Collect edited task outputs. Returns only changed values."""
|
||||
if not self._task_output_ids or self._selected_entry is None:
|
||||
return None
|
||||
overrides: dict[int, str] = {}
|
||||
for task_idx, editor_id, original in self._task_output_ids:
|
||||
editor = self.query_one(f"#{editor_id}", TextArea)
|
||||
if editor.text != original:
|
||||
overrides[task_idx] = editor.text
|
||||
return overrides or None
|
||||
|
||||
def _detect_entity_type(self, entry: dict[str, Any]) -> Literal["crew", "flow"]:
|
||||
"""Infer the top-level entity type from checkpoint entities."""
|
||||
for ent in entry.get("entities", []):
|
||||
if ent.get("type") == "flow":
|
||||
return "flow"
|
||||
return "crew"
|
||||
|
||||
def _resolve_location(self, entry: dict[str, Any]) -> str:
|
||||
"""Get the restore location string for a checkpoint entry."""
|
||||
if "path" in entry:
|
||||
return str(entry["path"])
|
||||
if _is_sqlite(self._location):
|
||||
return f"{self._location}#{entry['name']}"
|
||||
return str(entry.get("name", ""))
|
||||
|
||||
async def on_tree_node_highlighted(
|
||||
self, event: Tree.NodeHighlighted[dict[str, Any]]
|
||||
) -> None:
|
||||
if event.node.data is not None:
|
||||
await self._show_detail(event.node.data)
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
if self._selected_entry is None:
|
||||
return
|
||||
inputs = self._collect_inputs()
|
||||
overrides = self._collect_task_overrides()
|
||||
loc = self._resolve_location(self._selected_entry)
|
||||
etype = self._detect_entity_type(self._selected_entry)
|
||||
if event.button.id == "btn-resume":
|
||||
self.exit((loc, "resume", inputs, overrides, etype))
|
||||
elif event.button.id == "btn-fork":
|
||||
self.exit((loc, "fork", inputs, overrides, etype))
|
||||
|
||||
def action_refresh(self) -> None:
|
||||
self._refresh_tree()
|
||||
|
||||
|
||||
def _apply_task_overrides(crew: Any, task_overrides: dict[int, str]) -> None:
|
||||
"""Apply task output overrides to a restored Crew and print modifications."""
|
||||
import click
|
||||
|
||||
click.echo("Modifications:")
|
||||
overridden_agents: set[int] = set()
|
||||
for task_idx, new_output in task_overrides.items():
|
||||
if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None:
|
||||
desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}"
|
||||
if len(desc) > 60:
|
||||
desc = desc[:57] + "..."
|
||||
crew.tasks[task_idx].output.raw = new_output
|
||||
preview = new_output.replace("\n", " ")
|
||||
if len(preview) > 80:
|
||||
preview = preview[:77] + "..."
|
||||
click.echo(f" Task {task_idx + 1}: {desc}")
|
||||
click.echo(f" -> {preview}")
|
||||
agent = crew.tasks[task_idx].agent
|
||||
if agent and agent.agent_executor:
|
||||
nth = sum(1 for t in crew.tasks[:task_idx] if t.agent is agent)
|
||||
messages = agent.agent_executor.messages
|
||||
system_positions = [
|
||||
i for i, m in enumerate(messages) if m.get("role") == "system"
|
||||
]
|
||||
if nth < len(system_positions):
|
||||
seg_start = system_positions[nth]
|
||||
seg_end = (
|
||||
system_positions[nth + 1]
|
||||
if nth + 1 < len(system_positions)
|
||||
else len(messages)
|
||||
)
|
||||
for j in range(seg_end - 1, seg_start, -1):
|
||||
if messages[j].get("role") == "assistant":
|
||||
messages[j]["content"] = new_output
|
||||
break
|
||||
overridden_agents.add(id(agent))
|
||||
|
||||
earliest = min(task_overrides)
|
||||
for offset, subsequent in enumerate(crew.tasks[earliest + 1 :], start=earliest + 1):
|
||||
if subsequent.output and offset not in task_overrides:
|
||||
subsequent.output = None
|
||||
if subsequent.agent and subsequent.agent.agent_executor:
|
||||
subsequent.agent.agent_executor._resuming = False
|
||||
if id(subsequent.agent) not in overridden_agents:
|
||||
subsequent.agent.agent_executor.messages = []
|
||||
click.echo()
|
||||
|
||||
|
||||
async def _run_checkpoint_tui_async(location: str) -> None:
|
||||
"""Async implementation of the checkpoint TUI flow."""
|
||||
import click
|
||||
|
||||
app = CheckpointTUI(location=location)
|
||||
selection = await app.run_async()
|
||||
|
||||
if selection is None:
|
||||
return
|
||||
|
||||
selected, action, inputs, task_overrides, entity_type = selection
|
||||
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
|
||||
config = CheckpointConfig(restore_from=selected)
|
||||
|
||||
if entity_type == "flow":
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
if action == "fork":
|
||||
click.echo(f"\nForking flow from: {selected}\n")
|
||||
flow = Flow.fork(config)
|
||||
else:
|
||||
click.echo(f"\nResuming flow from: {selected}\n")
|
||||
flow = Flow.from_checkpoint(config)
|
||||
|
||||
if task_overrides:
|
||||
from crewai.crew import Crew as CrewCls
|
||||
|
||||
state = crewai_event_bus._runtime_state
|
||||
if state is not None:
|
||||
flat_offset = 0
|
||||
for entity in state.root:
|
||||
if not isinstance(entity, CrewCls) or not entity.tasks:
|
||||
continue
|
||||
n = len(entity.tasks)
|
||||
local = {
|
||||
idx - flat_offset: out
|
||||
for idx, out in task_overrides.items()
|
||||
if flat_offset <= idx < flat_offset + n
|
||||
}
|
||||
if local:
|
||||
_apply_task_overrides(entity, local)
|
||||
flat_offset += n
|
||||
|
||||
if inputs:
|
||||
click.echo("Inputs:")
|
||||
for k, v in inputs.items():
|
||||
click.echo(f" {k}: {v}")
|
||||
click.echo()
|
||||
|
||||
result = await flow.kickoff_async(inputs=inputs)
|
||||
click.echo(f"\nResult: {getattr(result, 'raw', result)}")
|
||||
return
|
||||
|
||||
from crewai.crew import Crew
|
||||
|
||||
if action == "fork":
|
||||
click.echo(f"\nForking from: {selected}\n")
|
||||
crew = Crew.fork(config)
|
||||
else:
|
||||
click.echo(f"\nResuming from: {selected}\n")
|
||||
crew = Crew.from_checkpoint(config)
|
||||
|
||||
if task_overrides:
|
||||
_apply_task_overrides(crew, task_overrides)
|
||||
|
||||
if inputs:
|
||||
click.echo("Inputs:")
|
||||
for k, v in inputs.items():
|
||||
click.echo(f" {k}: {v}")
|
||||
click.echo()
|
||||
|
||||
result = await crew.akickoff(inputs=inputs)
|
||||
click.echo(f"\nResult: {getattr(result, 'raw', result)}")
|
||||
|
||||
|
||||
def run_checkpoint_tui(location: str = "./.checkpoints") -> None:
|
||||
"""Launch the checkpoint browser TUI."""
|
||||
import asyncio
|
||||
|
||||
asyncio.run(_run_checkpoint_tui_async(location))
|
||||
@@ -18,6 +18,7 @@ from crewai.cli.install_crew import install_crew
|
||||
from crewai.cli.kickoff_flow import kickoff_flow
|
||||
from crewai.cli.organization.main import OrganizationCommand
|
||||
from crewai.cli.plot_flow import plot_flow
|
||||
from crewai.cli.remote_template.main import TemplateCommand
|
||||
from crewai.cli.replay_from_task import replay_task_command
|
||||
from crewai.cli.reset_memories_command import reset_memories_command
|
||||
from crewai.cli.run_crew import run_crew
|
||||
@@ -392,10 +393,15 @@ def deploy() -> None:
|
||||
|
||||
@deploy.command(name="create")
|
||||
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
|
||||
def deploy_create(yes: bool) -> None:
|
||||
@click.option(
|
||||
"--skip-validate",
|
||||
is_flag=True,
|
||||
help="Skip the pre-deploy validation checks.",
|
||||
)
|
||||
def deploy_create(yes: bool, skip_validate: bool) -> None:
|
||||
"""Create a Crew deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.create_crew(yes)
|
||||
deploy_cmd.create_crew(yes, skip_validate=skip_validate)
|
||||
|
||||
|
||||
@deploy.command(name="list")
|
||||
@@ -407,10 +413,28 @@ def deploy_list() -> None:
|
||||
|
||||
@deploy.command(name="push")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_push(uuid: str | None) -> None:
|
||||
@click.option(
|
||||
"--skip-validate",
|
||||
is_flag=True,
|
||||
help="Skip the pre-deploy validation checks.",
|
||||
)
|
||||
def deploy_push(uuid: str | None, skip_validate: bool) -> None:
|
||||
"""Deploy the Crew."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.deploy(uuid=uuid)
|
||||
deploy_cmd.deploy(uuid=uuid, skip_validate=skip_validate)
|
||||
|
||||
|
||||
@deploy.command(name="validate")
|
||||
def deploy_validate() -> None:
|
||||
"""Validate the current project against common deployment failures.
|
||||
|
||||
Runs the same pre-deploy checks that `crewai deploy create` and
|
||||
`crewai deploy push` run automatically, without contacting the platform.
|
||||
Exits non-zero if any blocking issues are found.
|
||||
"""
|
||||
from crewai.cli.deploy.validate import run_validate_command
|
||||
|
||||
run_validate_command()
|
||||
|
||||
|
||||
@deploy.command(name="status")
|
||||
@@ -473,6 +497,33 @@ def tool_publish(is_public: bool, force: bool) -> None:
|
||||
tool_cmd.publish(is_public, force)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def template() -> None:
|
||||
"""Browse and install project templates."""
|
||||
|
||||
|
||||
@template.command(name="list")
|
||||
def template_list() -> None:
|
||||
"""List available templates and select one to install."""
|
||||
template_cmd = TemplateCommand()
|
||||
template_cmd.list_templates()
|
||||
|
||||
|
||||
@template.command(name="add")
|
||||
@click.argument("name")
|
||||
@click.option(
|
||||
"-o",
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory name for the template (defaults to template name)",
|
||||
)
|
||||
def template_add(name: str, output_dir: str | None) -> None:
|
||||
"""Add a template to the current directory."""
|
||||
template_cmd = TemplateCommand()
|
||||
template_cmd.add_template(name, output_dir)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def flow() -> None:
|
||||
"""Flow related commands."""
|
||||
@@ -786,27 +837,83 @@ def traces_status() -> None:
|
||||
console.print(panel)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def checkpoint() -> None:
|
||||
"""Inspect checkpoint files."""
|
||||
@crewai.group(invoke_without_command=True)
|
||||
@click.option(
|
||||
"--location", default="./.checkpoints", help="Checkpoint directory or SQLite file."
|
||||
)
|
||||
@click.pass_context
|
||||
def checkpoint(ctx: click.Context, location: str) -> None:
|
||||
"""Browse and inspect checkpoints. Launches a TUI when called without a subcommand."""
|
||||
from crewai.cli.checkpoint_cli import _detect_location
|
||||
|
||||
location = _detect_location(location)
|
||||
ctx.ensure_object(dict)
|
||||
ctx.obj["location"] = location
|
||||
if ctx.invoked_subcommand is None:
|
||||
from crewai.cli.checkpoint_tui import run_checkpoint_tui
|
||||
|
||||
run_checkpoint_tui(location)
|
||||
|
||||
|
||||
@checkpoint.command("list")
|
||||
@click.argument("location", default="./.checkpoints")
|
||||
def checkpoint_list(location: str) -> None:
|
||||
"""List checkpoints in a directory."""
|
||||
from crewai.cli.checkpoint_cli import list_checkpoints
|
||||
from crewai.cli.checkpoint_cli import _detect_location, list_checkpoints
|
||||
|
||||
list_checkpoints(location)
|
||||
list_checkpoints(_detect_location(location))
|
||||
|
||||
|
||||
@checkpoint.command("info")
|
||||
@click.argument("path", default="./.checkpoints")
|
||||
def checkpoint_info(path: str) -> None:
|
||||
"""Show details of a checkpoint. Pass a file or directory for latest."""
|
||||
from crewai.cli.checkpoint_cli import info_checkpoint
|
||||
from crewai.cli.checkpoint_cli import _detect_location, info_checkpoint
|
||||
|
||||
info_checkpoint(path)
|
||||
info_checkpoint(_detect_location(path))
|
||||
|
||||
|
||||
@checkpoint.command("resume")
|
||||
@click.argument("checkpoint_id", required=False, default=None)
|
||||
@click.pass_context
|
||||
def checkpoint_resume(ctx: click.Context, checkpoint_id: str | None) -> None:
|
||||
"""Resume from a checkpoint. Defaults to the most recent."""
|
||||
from crewai.cli.checkpoint_cli import resume_checkpoint
|
||||
|
||||
resume_checkpoint(ctx.obj["location"], checkpoint_id)
|
||||
|
||||
|
||||
@checkpoint.command("diff")
|
||||
@click.argument("id1")
|
||||
@click.argument("id2")
|
||||
@click.pass_context
|
||||
def checkpoint_diff(ctx: click.Context, id1: str, id2: str) -> None:
|
||||
"""Compare two checkpoints side-by-side."""
|
||||
from crewai.cli.checkpoint_cli import diff_checkpoints
|
||||
|
||||
diff_checkpoints(ctx.obj["location"], id1, id2)
|
||||
|
||||
|
||||
@checkpoint.command("prune")
|
||||
@click.option(
|
||||
"--keep", type=int, default=None, help="Keep the N most recent checkpoints."
|
||||
)
|
||||
@click.option(
|
||||
"--older-than",
|
||||
default=None,
|
||||
help="Remove checkpoints older than duration (e.g. 7d, 24h, 30m).",
|
||||
)
|
||||
@click.option(
|
||||
"--dry-run", is_flag=True, help="Show what would be pruned without deleting."
|
||||
)
|
||||
@click.pass_context
|
||||
def checkpoint_prune(
|
||||
ctx: click.Context, keep: int | None, older_than: str | None, dry_run: bool
|
||||
) -> None:
|
||||
"""Remove old checkpoints."""
|
||||
from crewai.cli.checkpoint_cli import prune_checkpoints
|
||||
|
||||
prune_checkpoints(ctx.obj["location"], keep, older_than, dry_run)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -13,7 +13,6 @@ from packaging import version
|
||||
import tomli
|
||||
|
||||
from crewai.cli.utils import read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.crew import Crew
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
@@ -21,6 +20,7 @@ from crewai.types.crew_chat import ChatInputField, ChatInputs
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.printer import PRINTER
|
||||
from crewai.utilities.types import LLMMessage
|
||||
from crewai.utilities.version import get_crewai_version
|
||||
|
||||
|
||||
MIN_REQUIRED_VERSION: Final[Literal["0.98.0"]] = "0.98.0"
|
||||
|
||||
@@ -4,12 +4,35 @@ from rich.console import Console
|
||||
|
||||
from crewai.cli import git
|
||||
from crewai.cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai.cli.deploy.validate import validate_project
|
||||
from crewai.cli.utils import fetch_and_json_env_file, get_project_name
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def _run_predeploy_validation(skip_validate: bool) -> bool:
|
||||
"""Run pre-deploy validation unless skipped.
|
||||
|
||||
Returns True if deployment should proceed, False if it should abort.
|
||||
"""
|
||||
if skip_validate:
|
||||
console.print(
|
||||
"[yellow]Skipping pre-deploy validation (--skip-validate).[/yellow]"
|
||||
)
|
||||
return True
|
||||
|
||||
console.print("Running pre-deploy validation...", style="bold blue")
|
||||
validator = validate_project()
|
||||
if not validator.ok:
|
||||
console.print(
|
||||
"\n[bold red]Pre-deploy validation failed. "
|
||||
"Fix the issues above or re-run with --skip-validate.[/bold red]"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
"""
|
||||
A class to handle deployment-related operations for CrewAI projects.
|
||||
@@ -60,13 +83,16 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}"
|
||||
)
|
||||
|
||||
def deploy(self, uuid: str | None = None) -> None:
|
||||
def deploy(self, uuid: str | None = None, skip_validate: bool = False) -> None:
|
||||
"""
|
||||
Deploy a crew using either UUID or project name.
|
||||
|
||||
Args:
|
||||
uuid (Optional[str]): The UUID of the crew to deploy.
|
||||
skip_validate (bool): Skip pre-deploy validation checks.
|
||||
"""
|
||||
if not _run_predeploy_validation(skip_validate):
|
||||
return
|
||||
self._telemetry.start_deployment_span(uuid)
|
||||
console.print("Starting deployment...", style="bold blue")
|
||||
if uuid:
|
||||
@@ -80,10 +106,16 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
self._validate_response(response)
|
||||
self._display_deployment_info(response.json())
|
||||
|
||||
def create_crew(self, confirm: bool = False) -> None:
|
||||
def create_crew(self, confirm: bool = False, skip_validate: bool = False) -> None:
|
||||
"""
|
||||
Create a new crew deployment.
|
||||
|
||||
Args:
|
||||
confirm (bool): Whether to skip the interactive confirmation prompt.
|
||||
skip_validate (bool): Skip pre-deploy validation checks.
|
||||
"""
|
||||
if not _run_predeploy_validation(skip_validate):
|
||||
return
|
||||
self._telemetry.create_crew_deployment_span()
|
||||
console.print("Creating deployment...", style="bold blue")
|
||||
env_vars = fetch_and_json_env_file()
|
||||
|
||||
845
lib/crewai/src/crewai/cli/deploy/validate.py
Normal file
845
lib/crewai/src/crewai/cli/deploy/validate.py
Normal file
@@ -0,0 +1,845 @@
|
||||
"""Pre-deploy validation for CrewAI projects.
|
||||
|
||||
Catches locally what a deploy would reject at build or runtime so users
|
||||
don't burn deployment attempts on fixable project-structure problems.
|
||||
|
||||
Each check is grouped into one of:
|
||||
- ERROR: will block a deployment; validator exits non-zero.
|
||||
- WARNING: may still deploy but is almost always a deployment bug; printed
|
||||
but does not block.
|
||||
|
||||
The individual checks mirror the categories observed in production
|
||||
deployment-failure logs:
|
||||
|
||||
1. pyproject.toml present with ``[project].name``
|
||||
2. lockfile (``uv.lock`` or ``poetry.lock``) present and not stale
|
||||
3. package directory at ``src/<package>/`` exists (no empty name, no egg-info)
|
||||
4. standard crew files: ``crew.py``, ``config/agents.yaml``, ``config/tasks.yaml``
|
||||
5. flow entrypoint: ``main.py`` with a Flow subclass
|
||||
6. hatch wheel target resolves (packages = [...] or default dir matches name)
|
||||
7. crew/flow module imports cleanly (catches ``@CrewBase not found``,
|
||||
``No Flow subclass found``, provider import errors)
|
||||
8. environment variables referenced in code vs ``.env`` / deployment env
|
||||
9. installed crewai vs lockfile pin (catches missing-attribute failures from
|
||||
stale pins)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.cli.utils import parse_toml
|
||||
|
||||
|
||||
console = Console()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Severity(str, Enum):
|
||||
"""Severity of a validation finding."""
|
||||
|
||||
ERROR = "error"
|
||||
WARNING = "warning"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""A single finding from a validation check.
|
||||
|
||||
Attributes:
|
||||
severity: whether this blocks deploy or is advisory.
|
||||
code: stable short identifier, used in tests and docs
|
||||
(e.g. ``missing_pyproject``, ``stale_lockfile``).
|
||||
title: one-line summary shown to the user.
|
||||
detail: optional multi-line explanation.
|
||||
hint: optional remediation suggestion.
|
||||
"""
|
||||
|
||||
severity: Severity
|
||||
code: str
|
||||
title: str
|
||||
detail: str = ""
|
||||
hint: str = ""
|
||||
|
||||
|
||||
# Maps known provider env var names → label used in hint messages.
|
||||
_KNOWN_API_KEY_HINTS: dict[str, str] = {
|
||||
"OPENAI_API_KEY": "OpenAI",
|
||||
"ANTHROPIC_API_KEY": "Anthropic",
|
||||
"GOOGLE_API_KEY": "Google",
|
||||
"GEMINI_API_KEY": "Gemini",
|
||||
"AZURE_OPENAI_API_KEY": "Azure OpenAI",
|
||||
"AZURE_API_KEY": "Azure",
|
||||
"AWS_ACCESS_KEY_ID": "AWS",
|
||||
"AWS_SECRET_ACCESS_KEY": "AWS",
|
||||
"COHERE_API_KEY": "Cohere",
|
||||
"GROQ_API_KEY": "Groq",
|
||||
"MISTRAL_API_KEY": "Mistral",
|
||||
"TAVILY_API_KEY": "Tavily",
|
||||
"SERPER_API_KEY": "Serper",
|
||||
"SERPLY_API_KEY": "Serply",
|
||||
"PERPLEXITY_API_KEY": "Perplexity",
|
||||
"DEEPSEEK_API_KEY": "DeepSeek",
|
||||
"OPENROUTER_API_KEY": "OpenRouter",
|
||||
"FIRECRAWL_API_KEY": "Firecrawl",
|
||||
"EXA_API_KEY": "Exa",
|
||||
"BROWSERBASE_API_KEY": "Browserbase",
|
||||
}
|
||||
|
||||
|
||||
def normalize_package_name(project_name: str) -> str:
|
||||
"""Normalize a pyproject project.name into a Python package directory name.
|
||||
|
||||
Mirrors the rules in ``crewai.cli.create_crew.create_crew`` so the
|
||||
validator agrees with the scaffolder about where ``src/<pkg>/`` should
|
||||
live.
|
||||
"""
|
||||
folder = project_name.replace(" ", "_").replace("-", "_").lower()
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "", folder)
|
||||
|
||||
|
||||
class DeployValidator:
|
||||
"""Runs the full pre-deploy validation suite against a project directory."""
|
||||
|
||||
def __init__(self, project_root: Path | None = None) -> None:
|
||||
self.project_root: Path = (project_root or Path.cwd()).resolve()
|
||||
self.results: list[ValidationResult] = []
|
||||
self._pyproject: dict[str, Any] | None = None
|
||||
self._project_name: str | None = None
|
||||
self._package_name: str | None = None
|
||||
self._package_dir: Path | None = None
|
||||
self._is_flow: bool = False
|
||||
|
||||
def _add(
|
||||
self,
|
||||
severity: Severity,
|
||||
code: str,
|
||||
title: str,
|
||||
detail: str = "",
|
||||
hint: str = "",
|
||||
) -> None:
|
||||
self.results.append(
|
||||
ValidationResult(
|
||||
severity=severity,
|
||||
code=code,
|
||||
title=title,
|
||||
detail=detail,
|
||||
hint=hint,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def errors(self) -> list[ValidationResult]:
|
||||
return [r for r in self.results if r.severity is Severity.ERROR]
|
||||
|
||||
@property
|
||||
def warnings(self) -> list[ValidationResult]:
|
||||
return [r for r in self.results if r.severity is Severity.WARNING]
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.errors
|
||||
|
||||
def run(self) -> list[ValidationResult]:
|
||||
"""Run all checks. Later checks are skipped when earlier ones make
|
||||
them impossible (e.g. no pyproject.toml → no lockfile check)."""
|
||||
if not self._check_pyproject():
|
||||
return self.results
|
||||
|
||||
self._check_lockfile()
|
||||
|
||||
if not self._check_package_dir():
|
||||
self._check_hatch_wheel_target()
|
||||
return self.results
|
||||
|
||||
if self._is_flow:
|
||||
self._check_flow_entrypoint()
|
||||
else:
|
||||
self._check_crew_entrypoint()
|
||||
self._check_config_yamls()
|
||||
|
||||
self._check_hatch_wheel_target()
|
||||
self._check_module_imports()
|
||||
self._check_env_vars()
|
||||
self._check_version_vs_lockfile()
|
||||
|
||||
return self.results
|
||||
|
||||
def _check_pyproject(self) -> bool:
|
||||
pyproject_path = self.project_root / "pyproject.toml"
|
||||
if not pyproject_path.exists():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_pyproject",
|
||||
"Cannot find pyproject.toml",
|
||||
detail=(
|
||||
f"Expected pyproject.toml at {pyproject_path}. "
|
||||
"CrewAI projects must be installable Python packages."
|
||||
),
|
||||
hint="Run `crewai create crew <name>` to scaffold a valid project layout.",
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
self._pyproject = parse_toml(pyproject_path.read_text())
|
||||
except Exception as e:
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"invalid_pyproject",
|
||||
"pyproject.toml is not valid TOML",
|
||||
detail=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
project = self._pyproject.get("project") or {}
|
||||
name = project.get("name")
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_project_name",
|
||||
"pyproject.toml is missing [project].name",
|
||||
detail=(
|
||||
"Without a project name the platform cannot resolve your "
|
||||
"package directory (this produces errors like "
|
||||
"'Cannot find src//crew.py')."
|
||||
),
|
||||
hint='Set a `name = "..."` field under `[project]` in pyproject.toml.',
|
||||
)
|
||||
return False
|
||||
|
||||
self._project_name = name
|
||||
self._package_name = normalize_package_name(name)
|
||||
self._is_flow = (self._pyproject.get("tool") or {}).get("crewai", {}).get(
|
||||
"type"
|
||||
) == "flow"
|
||||
return True
|
||||
|
||||
def _check_lockfile(self) -> None:
|
||||
uv_lock = self.project_root / "uv.lock"
|
||||
poetry_lock = self.project_root / "poetry.lock"
|
||||
pyproject = self.project_root / "pyproject.toml"
|
||||
|
||||
if not uv_lock.exists() and not poetry_lock.exists():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_lockfile",
|
||||
"Expected to find at least one of these files: uv.lock or poetry.lock",
|
||||
hint=(
|
||||
"Run `uv lock` (recommended) or `poetry lock` in your project "
|
||||
"directory, commit the lockfile, then redeploy."
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
lockfile = uv_lock if uv_lock.exists() else poetry_lock
|
||||
try:
|
||||
if lockfile.stat().st_mtime < pyproject.stat().st_mtime:
|
||||
self._add(
|
||||
Severity.WARNING,
|
||||
"stale_lockfile",
|
||||
f"{lockfile.name} is older than pyproject.toml",
|
||||
detail=(
|
||||
"Your lockfile may not reflect recent dependency changes. "
|
||||
"The platform resolves from the lockfile, so deployed "
|
||||
"dependencies may differ from local."
|
||||
),
|
||||
hint="Run `uv lock` (or `poetry lock`) and commit the result.",
|
||||
)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _check_package_dir(self) -> bool:
|
||||
if self._package_name is None:
|
||||
return False
|
||||
|
||||
src_dir = self.project_root / "src"
|
||||
if not src_dir.is_dir():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_src_dir",
|
||||
"Missing src/ directory",
|
||||
detail=(
|
||||
"CrewAI deployments expect a src-layout project: "
|
||||
f"src/{self._package_name}/crew.py (or main.py for flows)."
|
||||
),
|
||||
hint="Run `crewai create crew <name>` to see the expected layout.",
|
||||
)
|
||||
return False
|
||||
|
||||
package_dir = src_dir / self._package_name
|
||||
if not package_dir.is_dir():
|
||||
siblings = [
|
||||
p.name
|
||||
for p in src_dir.iterdir()
|
||||
if p.is_dir() and not p.name.endswith(".egg-info")
|
||||
]
|
||||
egg_info = [
|
||||
p.name for p in src_dir.iterdir() if p.name.endswith(".egg-info")
|
||||
]
|
||||
|
||||
hint_parts = [
|
||||
f'Create src/{self._package_name}/ to match [project].name = "{self._project_name}".'
|
||||
]
|
||||
if siblings:
|
||||
hint_parts.append(
|
||||
f"Found other package directories: {', '.join(siblings)}. "
|
||||
f"Either rename one to '{self._package_name}' or update [project].name."
|
||||
)
|
||||
if egg_info:
|
||||
hint_parts.append(
|
||||
f"Delete stale build artifacts: {', '.join(egg_info)} "
|
||||
"(these confuse the platform's package discovery)."
|
||||
)
|
||||
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_package_dir",
|
||||
f"Cannot find src/{self._package_name}/",
|
||||
detail=(
|
||||
"The platform looks for your crew source under "
|
||||
"src/<package_name>/, derived from [project].name."
|
||||
),
|
||||
hint=" ".join(hint_parts),
|
||||
)
|
||||
return False
|
||||
|
||||
for p in src_dir.iterdir():
|
||||
if p.name.endswith(".egg-info"):
|
||||
self._add(
|
||||
Severity.WARNING,
|
||||
"stale_egg_info",
|
||||
f"Stale build artifact in src/: {p.name}",
|
||||
detail=(
|
||||
".egg-info directories can be mistaken for your package "
|
||||
"and cause 'Cannot find src/<name>.egg-info/crew.py' errors."
|
||||
),
|
||||
hint=f"Delete {p} and add `*.egg-info/` to .gitignore.",
|
||||
)
|
||||
|
||||
self._package_dir = package_dir
|
||||
return True
|
||||
|
||||
def _check_crew_entrypoint(self) -> None:
|
||||
if self._package_dir is None:
|
||||
return
|
||||
crew_py = self._package_dir / "crew.py"
|
||||
if not crew_py.is_file():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_crew_py",
|
||||
f"Cannot find {crew_py.relative_to(self.project_root)}",
|
||||
detail=(
|
||||
"Standard crew projects must define a Crew class decorated "
|
||||
"with @CrewBase inside crew.py."
|
||||
),
|
||||
hint=(
|
||||
"Create crew.py with an @CrewBase-annotated class, or set "
|
||||
'`[tool.crewai] type = "flow"` in pyproject.toml if this is a flow.'
|
||||
),
|
||||
)
|
||||
|
||||
def _check_config_yamls(self) -> None:
|
||||
if self._package_dir is None:
|
||||
return
|
||||
config_dir = self._package_dir / "config"
|
||||
if not config_dir.is_dir():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_config_dir",
|
||||
f"Cannot find {config_dir.relative_to(self.project_root)}",
|
||||
hint="Create a config/ directory with agents.yaml and tasks.yaml.",
|
||||
)
|
||||
return
|
||||
|
||||
for yaml_name in ("agents.yaml", "tasks.yaml"):
|
||||
yaml_path = config_dir / yaml_name
|
||||
if not yaml_path.is_file():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
f"missing_{yaml_name.replace('.', '_')}",
|
||||
f"Cannot find {yaml_path.relative_to(self.project_root)}",
|
||||
detail=(
|
||||
"CrewAI loads agent and task config from these files; "
|
||||
"missing them causes empty-config warnings and runtime crashes."
|
||||
),
|
||||
)
|
||||
|
||||
def _check_flow_entrypoint(self) -> None:
|
||||
if self._package_dir is None:
|
||||
return
|
||||
main_py = self._package_dir / "main.py"
|
||||
if not main_py.is_file():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_flow_main",
|
||||
f"Cannot find {main_py.relative_to(self.project_root)}",
|
||||
detail=(
|
||||
"Flow projects must define a Flow subclass in main.py. "
|
||||
'This project has `[tool.crewai] type = "flow"` set.'
|
||||
),
|
||||
hint="Create main.py with a `class MyFlow(Flow[...])`.",
|
||||
)
|
||||
|
||||
def _check_hatch_wheel_target(self) -> None:
|
||||
if not self._pyproject:
|
||||
return
|
||||
|
||||
build_system = self._pyproject.get("build-system") or {}
|
||||
backend = build_system.get("build-backend", "")
|
||||
if "hatchling" not in backend:
|
||||
return
|
||||
|
||||
hatch_wheel = (
|
||||
(self._pyproject.get("tool") or {})
|
||||
.get("hatch", {})
|
||||
.get("build", {})
|
||||
.get("targets", {})
|
||||
.get("wheel", {})
|
||||
)
|
||||
if hatch_wheel.get("packages") or hatch_wheel.get("only-include"):
|
||||
return
|
||||
|
||||
if self._package_dir and self._package_dir.is_dir():
|
||||
return
|
||||
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"hatch_wheel_target_missing",
|
||||
"Hatchling cannot determine which files to ship",
|
||||
detail=(
|
||||
"Your pyproject uses hatchling but has no "
|
||||
"[tool.hatch.build.targets.wheel] configuration and no "
|
||||
"directory matching your project name."
|
||||
),
|
||||
hint=(
|
||||
"Add:\n"
|
||||
" [tool.hatch.build.targets.wheel]\n"
|
||||
f' packages = ["src/{self._package_name}"]'
|
||||
),
|
||||
)
|
||||
|
||||
def _check_module_imports(self) -> None:
|
||||
"""Import the user's crew/flow via `uv run` so the check sees the same
|
||||
package versions as `crewai run` would. Result is reported as JSON on
|
||||
the subprocess's stdout."""
|
||||
script = (
|
||||
"import json, sys, traceback, os\n"
|
||||
"os.chdir(sys.argv[1])\n"
|
||||
"try:\n"
|
||||
" from crewai.cli.utils import get_crews, get_flows\n"
|
||||
" is_flow = sys.argv[2] == 'flow'\n"
|
||||
" if is_flow:\n"
|
||||
" instances = get_flows()\n"
|
||||
" kind = 'flow'\n"
|
||||
" else:\n"
|
||||
" instances = get_crews()\n"
|
||||
" kind = 'crew'\n"
|
||||
" print(json.dumps({'ok': True, 'kind': kind, 'count': len(instances)}))\n"
|
||||
"except BaseException as e:\n"
|
||||
" print(json.dumps({\n"
|
||||
" 'ok': False,\n"
|
||||
" 'error_type': type(e).__name__,\n"
|
||||
" 'error': str(e),\n"
|
||||
" 'traceback': traceback.format_exc(),\n"
|
||||
" }))\n"
|
||||
)
|
||||
|
||||
uv_path = shutil.which("uv")
|
||||
if uv_path is None:
|
||||
self._add(
|
||||
Severity.WARNING,
|
||||
"uv_not_found",
|
||||
"Skipping import check: `uv` not installed",
|
||||
hint="Install uv: https://docs.astral.sh/uv/",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
proc = subprocess.run( # noqa: S603 - args constructed from trusted inputs
|
||||
[
|
||||
uv_path,
|
||||
"run",
|
||||
"python",
|
||||
"-c",
|
||||
script,
|
||||
str(self.project_root),
|
||||
"flow" if self._is_flow else "crew",
|
||||
],
|
||||
cwd=self.project_root,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
check=False,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"import_timeout",
|
||||
"Importing your crew/flow module timed out after 120s",
|
||||
detail=(
|
||||
"User code may be making network calls or doing heavy work "
|
||||
"at import time. Move that work into agent methods."
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
# The payload is the last JSON object on stdout; user code may print
|
||||
# other lines before it.
|
||||
payload: dict[str, Any] | None = None
|
||||
for line in reversed(proc.stdout.splitlines()):
|
||||
line = line.strip()
|
||||
if line.startswith("{") and line.endswith("}"):
|
||||
try:
|
||||
payload = json.loads(line)
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if payload is None:
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"import_failed",
|
||||
"Could not import your crew/flow module",
|
||||
detail=(proc.stderr or proc.stdout or "").strip()[:1500],
|
||||
hint="Run `crewai run` locally first to reproduce the error.",
|
||||
)
|
||||
return
|
||||
|
||||
if payload.get("ok"):
|
||||
if payload.get("count", 0) == 0:
|
||||
kind = payload.get("kind", "crew")
|
||||
if kind == "flow":
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"no_flow_subclass",
|
||||
"No Flow subclass found in the module",
|
||||
hint=(
|
||||
"main.py must define a class extending "
|
||||
"`crewai.flow.Flow`, instantiable with no arguments."
|
||||
),
|
||||
)
|
||||
else:
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"no_crewbase_class",
|
||||
"Crew class annotated with @CrewBase not found",
|
||||
hint=(
|
||||
"Decorate your crew class with @CrewBase from "
|
||||
"crewai.project (see `crewai create crew` template)."
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
err_msg = str(payload.get("error", ""))
|
||||
err_type = str(payload.get("error_type", "Exception"))
|
||||
tb = str(payload.get("traceback", ""))
|
||||
self._classify_import_error(err_type, err_msg, tb)
|
||||
|
||||
def _classify_import_error(self, err_type: str, err_msg: str, tb: str) -> None:
|
||||
"""Turn a raw import-time exception into a user-actionable finding."""
|
||||
# Must be checked before the generic "native provider" branch below:
|
||||
# the extras-missing message contains the same phrase. Providers
|
||||
# format the install command as plain text (`to install: uv add
|
||||
# "crewai[extra]"`); also tolerate backtick-delimited variants.
|
||||
m = re.search(
|
||||
r"(?P<pkg>[A-Za-z0-9_ -]+?)\s+native provider not available"
|
||||
r".*?to install:\s*`?(?P<cmd>uv add [\"']crewai\[[^\]]+\][\"'])`?",
|
||||
err_msg,
|
||||
)
|
||||
if m:
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_provider_extra",
|
||||
f"{m.group('pkg').strip()} provider extra not installed",
|
||||
hint=f"Run: {m.group('cmd')}",
|
||||
)
|
||||
return
|
||||
|
||||
# crewai.llm.LLM.__new__ wraps provider init errors as
|
||||
# ImportError("Error importing native provider: ...").
|
||||
if "Error importing native provider" in err_msg or "native provider" in err_msg:
|
||||
missing_key = self._extract_missing_api_key(err_msg)
|
||||
if missing_key:
|
||||
provider = _KNOWN_API_KEY_HINTS.get(missing_key, missing_key)
|
||||
self._add(
|
||||
Severity.WARNING,
|
||||
"llm_init_missing_key",
|
||||
f"LLM is constructed at import time but {missing_key} is not set",
|
||||
detail=(
|
||||
f"Your crew instantiates a {provider} LLM during module "
|
||||
"load (e.g. in a class field default or @crew method). "
|
||||
f"The {provider} provider currently requires {missing_key} "
|
||||
"at construction time, so this will fail on the platform "
|
||||
"unless the key is set in your deployment environment."
|
||||
),
|
||||
hint=(
|
||||
f"Add {missing_key} to your deployment's Environment "
|
||||
"Variables before deploying, or move LLM construction "
|
||||
"inside agent methods so it runs lazily."
|
||||
),
|
||||
)
|
||||
return
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"llm_provider_init_failed",
|
||||
"LLM native provider failed to initialize",
|
||||
detail=err_msg,
|
||||
hint=(
|
||||
"Check your LLM(model=...) configuration and provider-specific "
|
||||
"extras (e.g. `uv add 'crewai[azure-ai-inference]'` for Azure)."
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
if err_type == "KeyError":
|
||||
key = err_msg.strip("'\"")
|
||||
if key in _KNOWN_API_KEY_HINTS or key.endswith("_API_KEY"):
|
||||
self._add(
|
||||
Severity.WARNING,
|
||||
"env_var_read_at_import",
|
||||
f"{key} is read at import time via os.environ[...]",
|
||||
detail=(
|
||||
"Using os.environ[...] (rather than os.getenv(...)) "
|
||||
"at module scope crashes the build if the key isn't set."
|
||||
),
|
||||
hint=(
|
||||
f"Either add {key} as a deployment env var, or switch "
|
||||
"to os.getenv() and move the access inside agent methods."
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
if "Crew class annotated with @CrewBase not found" in err_msg:
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"no_crewbase_class",
|
||||
"Crew class annotated with @CrewBase not found",
|
||||
detail=err_msg,
|
||||
)
|
||||
return
|
||||
if "No Flow subclass found" in err_msg:
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"no_flow_subclass",
|
||||
"No Flow subclass found in the module",
|
||||
detail=err_msg,
|
||||
)
|
||||
return
|
||||
|
||||
if (
|
||||
err_type == "AttributeError"
|
||||
and "has no attribute '_load_response_format'" in err_msg
|
||||
):
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"stale_crewai_pin",
|
||||
"Your lockfile pins a crewai version missing `_load_response_format`",
|
||||
detail=err_msg,
|
||||
hint=(
|
||||
"Run `uv lock --upgrade-package crewai` (or `poetry update crewai`) "
|
||||
"to pin a newer release."
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
if "pydantic" in tb.lower() or "validation error" in err_msg.lower():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"pydantic_validation_error",
|
||||
"Pydantic validation failed while loading your crew",
|
||||
detail=err_msg[:800],
|
||||
hint=(
|
||||
"Check agent/task configuration fields. `crewai run` locally "
|
||||
"will show the full traceback."
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"import_failed",
|
||||
f"Importing your crew failed: {err_type}",
|
||||
detail=err_msg[:800],
|
||||
hint="Run `crewai run` locally to see the full traceback.",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_missing_api_key(err_msg: str) -> str | None:
|
||||
"""Pull 'FOO_API_KEY' out of '... FOO_API_KEY is required ...'."""
|
||||
m = re.search(r"([A-Z][A-Z0-9_]*_API_KEY)\s+is required", err_msg)
|
||||
if m:
|
||||
return m.group(1)
|
||||
m = re.search(r"['\"]([A-Z][A-Z0-9_]*_API_KEY)['\"]", err_msg)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return None
|
||||
|
||||
def _check_env_vars(self) -> None:
|
||||
"""Warn about env vars referenced in user code but missing locally.
|
||||
Best-effort only — the platform sets vars server-side, so we never error.
|
||||
"""
|
||||
if not self._package_dir:
|
||||
return
|
||||
|
||||
referenced: set[str] = set()
|
||||
pattern = re.compile(
|
||||
r"""(?x)
|
||||
(?:os\.environ\s*(?:\[\s*|\.get\s*\(\s*)
|
||||
|os\.getenv\s*\(\s*
|
||||
|getenv\s*\(\s*)
|
||||
['"]([A-Z][A-Z0-9_]*)['"]
|
||||
"""
|
||||
)
|
||||
|
||||
for path in self._package_dir.rglob("*.py"):
|
||||
try:
|
||||
text = path.read_text(encoding="utf-8", errors="ignore")
|
||||
except OSError:
|
||||
continue
|
||||
referenced.update(pattern.findall(text))
|
||||
|
||||
for path in self._package_dir.rglob("*.yaml"):
|
||||
try:
|
||||
text = path.read_text(encoding="utf-8", errors="ignore")
|
||||
except OSError:
|
||||
continue
|
||||
referenced.update(re.findall(r"\$\{?([A-Z][A-Z0-9_]+)\}?", text))
|
||||
|
||||
env_file = self.project_root / ".env"
|
||||
env_keys: set[str] = set()
|
||||
if env_file.exists():
|
||||
for line in env_file.read_text(errors="ignore").splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
env_keys.add(line.split("=", 1)[0].strip())
|
||||
|
||||
missing_known: list[str] = sorted(
|
||||
var
|
||||
for var in referenced
|
||||
if var in _KNOWN_API_KEY_HINTS
|
||||
and var not in env_keys
|
||||
and var not in os.environ
|
||||
)
|
||||
if missing_known:
|
||||
self._add(
|
||||
Severity.WARNING,
|
||||
"env_vars_not_in_dotenv",
|
||||
f"{len(missing_known)} referenced API key(s) not in .env",
|
||||
detail=(
|
||||
"These env vars are referenced in your source but not set "
|
||||
f"locally: {', '.join(missing_known)}. Deploys will fail "
|
||||
"unless they are added to the deployment's Environment "
|
||||
"Variables in the CrewAI dashboard."
|
||||
),
|
||||
)
|
||||
|
||||
def _check_version_vs_lockfile(self) -> None:
|
||||
"""Warn when the lockfile pins a crewai release older than 1.13.0,
|
||||
which is where ``_load_response_format`` was introduced.
|
||||
"""
|
||||
uv_lock = self.project_root / "uv.lock"
|
||||
poetry_lock = self.project_root / "poetry.lock"
|
||||
lockfile = (
|
||||
uv_lock
|
||||
if uv_lock.exists()
|
||||
else poetry_lock
|
||||
if poetry_lock.exists()
|
||||
else None
|
||||
)
|
||||
if lockfile is None:
|
||||
return
|
||||
|
||||
try:
|
||||
text = lockfile.read_text(errors="ignore")
|
||||
except OSError:
|
||||
return
|
||||
|
||||
m = re.search(
|
||||
r'name\s*=\s*"crewai"\s*\nversion\s*=\s*"([^"]+)"',
|
||||
text,
|
||||
)
|
||||
if not m:
|
||||
return
|
||||
locked = m.group(1)
|
||||
|
||||
try:
|
||||
from packaging.version import Version
|
||||
|
||||
if Version(locked) < Version("1.13.0"):
|
||||
self._add(
|
||||
Severity.WARNING,
|
||||
"old_crewai_pin",
|
||||
f"Lockfile pins crewai=={locked} (older than 1.13.0)",
|
||||
detail=(
|
||||
"Older pinned versions are missing API surface the "
|
||||
"platform builder expects (e.g. `_load_response_format`)."
|
||||
),
|
||||
hint="Run `uv lock --upgrade-package crewai` and redeploy.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Could not parse crewai pin from lockfile: %s", e)
|
||||
|
||||
|
||||
def render_report(results: list[ValidationResult]) -> None:
|
||||
"""Pretty-print results to the shared rich console."""
|
||||
if not results:
|
||||
console.print("[bold green]Pre-deploy validation passed.[/bold green]")
|
||||
return
|
||||
|
||||
errors = [r for r in results if r.severity is Severity.ERROR]
|
||||
warnings = [r for r in results if r.severity is Severity.WARNING]
|
||||
|
||||
for result in errors:
|
||||
console.print(f"[bold red]ERROR[/bold red] [{result.code}] {result.title}")
|
||||
if result.detail:
|
||||
console.print(f" {result.detail}")
|
||||
if result.hint:
|
||||
console.print(f" [dim]hint:[/dim] {result.hint}")
|
||||
|
||||
for result in warnings:
|
||||
console.print(
|
||||
f"[bold yellow]WARNING[/bold yellow] [{result.code}] {result.title}"
|
||||
)
|
||||
if result.detail:
|
||||
console.print(f" {result.detail}")
|
||||
if result.hint:
|
||||
console.print(f" [dim]hint:[/dim] {result.hint}")
|
||||
|
||||
summary_parts: list[str] = []
|
||||
if errors:
|
||||
summary_parts.append(f"[bold red]{len(errors)} error(s)[/bold red]")
|
||||
if warnings:
|
||||
summary_parts.append(f"[bold yellow]{len(warnings)} warning(s)[/bold yellow]")
|
||||
console.print(f"\n{' / '.join(summary_parts)}")
|
||||
|
||||
|
||||
def validate_project(project_root: Path | None = None) -> DeployValidator:
|
||||
"""Entrypoint: run validation, render results, return the validator.
|
||||
|
||||
The caller inspects ``validator.ok`` to decide whether to proceed with a
|
||||
deploy.
|
||||
"""
|
||||
validator = DeployValidator(project_root=project_root)
|
||||
validator.run()
|
||||
render_report(validator.results)
|
||||
return validator
|
||||
|
||||
|
||||
def run_validate_command() -> None:
|
||||
"""Implementation of `crewai deploy validate`."""
|
||||
validator = validate_project()
|
||||
if not validator.ok:
|
||||
sys.exit(1)
|
||||
@@ -7,7 +7,7 @@ from rich.console import Console
|
||||
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
|
||||
from crewai.cli.command import BaseCommand
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.utilities.version import get_crewai_version
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
@@ -6,7 +6,7 @@ import httpx
|
||||
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.utilities.version import get_crewai_version
|
||||
|
||||
|
||||
class PlusAPI:
|
||||
|
||||
250
lib/crewai/src/crewai/cli/remote_template/main.py
Normal file
250
lib/crewai/src/crewai/cli/remote_template/main.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any
|
||||
import zipfile
|
||||
|
||||
import click
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from crewai.cli.command import BaseCommand
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
console = Console()
|
||||
|
||||
GITHUB_ORG = "crewAIInc"
|
||||
TEMPLATE_PREFIX = "template_"
|
||||
GITHUB_API_BASE = "https://api.github.com"
|
||||
|
||||
BANNER = """\
|
||||
[bold white] ██████╗██████╗ ███████╗██╗ ██╗[/bold white] [bold red] █████╗ ██╗[/bold red]
|
||||
[bold white]██╔════╝██╔══██╗██╔════╝██║ ██║[/bold white] [bold red]██╔══██╗██║[/bold red]
|
||||
[bold white]██║ ██████╔╝█████╗ ██║ █╗ ██║[/bold white] [bold red]███████║██║[/bold red]
|
||||
[bold white]██║ ██╔══██╗██╔══╝ ██║███╗██║[/bold white] [bold red]██╔══██║██║[/bold red]
|
||||
[bold white]╚██████╗██║ ██║███████╗╚███╔███╔╝[/bold white] [bold red]██║ ██║██║[/bold red]
|
||||
[bold white] ╚═════╝╚═╝ ╚═╝╚══════╝ ╚══╝╚══╝[/bold white] [bold red]╚═╝ ╚═╝╚═╝[/bold red]
|
||||
[dim white]████████╗███████╗███╗ ███╗██████╗ ██╗ █████╗ ████████╗███████╗███████╗[/dim white]
|
||||
[dim white]╚══██╔══╝██╔════╝████╗ ████║██╔══██╗██║ ██╔══██╗╚══██╔══╝██╔════╝██╔════╝[/dim white]
|
||||
[dim white] ██║ █████╗ ██╔████╔██║██████╔╝██║ ███████║ ██║ █████╗ ███████╗[/dim white]
|
||||
[dim white] ██║ ██╔══╝ ██║╚██╔╝██║██╔═══╝ ██║ ██╔══██║ ██║ ██╔══╝ ╚════██║[/dim white]
|
||||
[dim white] ██║ ███████╗██║ ╚═╝ ██║██║ ███████╗██║ ██║ ██║ ███████╗███████║[/dim white]
|
||||
[dim white] ╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚══════╝[/dim white]"""
|
||||
|
||||
|
||||
class TemplateCommand(BaseCommand):
|
||||
"""Handle template-related operations for CrewAI projects."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def list_templates(self) -> None:
|
||||
"""List available templates with an interactive selector to install."""
|
||||
templates = self._fetch_templates()
|
||||
if not templates:
|
||||
click.echo("No templates found.")
|
||||
return
|
||||
|
||||
console.print(f"\n{BANNER}\n")
|
||||
console.print(" [on cyan] templates [/on cyan]\n")
|
||||
console.print(f" [green]o[/green] Source: https://github.com/{GITHUB_ORG}")
|
||||
console.print(
|
||||
f" [green]o[/green] Found [bold]{len(templates)}[/bold] templates\n"
|
||||
)
|
||||
console.print(" [green]o[/green] Select a template to install")
|
||||
|
||||
for idx, repo in enumerate(templates, start=1):
|
||||
name = repo["name"].removeprefix(TEMPLATE_PREFIX)
|
||||
description = repo.get("description") or ""
|
||||
if description:
|
||||
console.print(
|
||||
f" [bold cyan]{idx}.[/bold cyan] [bold white]{name}[/bold white] [dim]({description})[/dim]"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f" [bold cyan]{idx}.[/bold cyan] [bold white]{name}[/bold white]"
|
||||
)
|
||||
|
||||
console.print(" [bold cyan]q.[/bold cyan] [dim]Quit[/dim]\n")
|
||||
|
||||
while True:
|
||||
choice = click.prompt("Enter your choice", type=str)
|
||||
|
||||
if choice.lower() == "q":
|
||||
return
|
||||
|
||||
if choice.isdigit() and 1 <= int(choice) <= len(templates):
|
||||
selected_index = int(choice) - 1
|
||||
break
|
||||
|
||||
click.secho(
|
||||
f"Please enter a number between 1 and {len(templates)}, or 'q' to quit.",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
selected = templates[selected_index]
|
||||
repo_name = selected["name"]
|
||||
self._install_repo(repo_name)
|
||||
|
||||
def add_template(self, name: str, output_dir: str | None = None) -> None:
|
||||
"""Download a template and copy it into the current working directory.
|
||||
|
||||
Args:
|
||||
name: Template name (with or without the template_ prefix).
|
||||
output_dir: Optional directory name. Defaults to the template name.
|
||||
"""
|
||||
repo_name = self._resolve_repo_name(name)
|
||||
if repo_name is None:
|
||||
click.secho(f"Template '{name}' not found.", fg="red")
|
||||
click.echo("Run 'crewai template list' to see available templates.")
|
||||
raise SystemExit(1)
|
||||
|
||||
self._install_repo(repo_name, output_dir)
|
||||
|
||||
def _install_repo(self, repo_name: str, output_dir: str | None = None) -> None:
|
||||
"""Download and extract a template repo into the current directory.
|
||||
|
||||
Args:
|
||||
repo_name: Full GitHub repo name (e.g. template_deep_research).
|
||||
output_dir: Optional directory name. Defaults to the template name.
|
||||
"""
|
||||
folder_name = output_dir or repo_name.removeprefix(TEMPLATE_PREFIX)
|
||||
dest = os.path.join(os.getcwd(), folder_name)
|
||||
|
||||
while os.path.exists(dest):
|
||||
click.secho(f"Directory '{folder_name}' already exists.", fg="yellow")
|
||||
folder_name = click.prompt(
|
||||
"Enter a different directory name (or 'q' to quit)", type=str
|
||||
)
|
||||
if folder_name.lower() == "q":
|
||||
return
|
||||
dest = os.path.join(os.getcwd(), folder_name)
|
||||
|
||||
click.echo(
|
||||
f"Downloading template '{repo_name.removeprefix(TEMPLATE_PREFIX)}'..."
|
||||
)
|
||||
|
||||
zip_bytes = self._download_zip(repo_name)
|
||||
self._extract_zip(zip_bytes, dest)
|
||||
|
||||
self._telemetry.template_installed_span(repo_name.removeprefix(TEMPLATE_PREFIX))
|
||||
|
||||
console.print(
|
||||
f"\n [green]\u2713[/green] Installed template [bold white]{folder_name}[/bold white]"
|
||||
f" [dim](source: github.com/{GITHUB_ORG}/{repo_name})[/dim]\n"
|
||||
)
|
||||
|
||||
next_steps = Text()
|
||||
next_steps.append(f" cd {folder_name}\n", style="bold white")
|
||||
next_steps.append(" crewai install", style="bold white")
|
||||
|
||||
panel = Panel(
|
||||
next_steps,
|
||||
title="[green]\u25c7 Next steps[/green]",
|
||||
title_align="left",
|
||||
border_style="dim",
|
||||
padding=(1, 2),
|
||||
)
|
||||
console.print(panel)
|
||||
|
||||
def _fetch_templates(self) -> list[dict[str, Any]]:
|
||||
"""Fetch all template repos from the GitHub org."""
|
||||
templates: list[dict[str, Any]] = []
|
||||
page = 1
|
||||
while True:
|
||||
url = f"{GITHUB_API_BASE}/orgs/{GITHUB_ORG}/repos"
|
||||
params: dict[str, str | int] = {
|
||||
"per_page": 100,
|
||||
"page": page,
|
||||
"type": "public",
|
||||
}
|
||||
try:
|
||||
response = httpx.get(url, params=params, timeout=15)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
click.secho(f"Failed to fetch templates from GitHub: {e}", fg="red")
|
||||
raise SystemExit(1) from e
|
||||
|
||||
repos = response.json()
|
||||
if not repos:
|
||||
break
|
||||
|
||||
templates.extend(
|
||||
repo
|
||||
for repo in repos
|
||||
if repo["name"].startswith(TEMPLATE_PREFIX) and not repo.get("private")
|
||||
)
|
||||
|
||||
page += 1
|
||||
|
||||
templates.sort(key=lambda r: r["name"])
|
||||
return templates
|
||||
|
||||
def _resolve_repo_name(self, name: str) -> str | None:
|
||||
"""Resolve user input to a full repo name, or None if not found."""
|
||||
# Accept both 'deep_research' and 'template_deep_research'
|
||||
candidates = [
|
||||
f"{TEMPLATE_PREFIX}{name}"
|
||||
if not name.startswith(TEMPLATE_PREFIX)
|
||||
else name,
|
||||
name,
|
||||
]
|
||||
|
||||
templates = self._fetch_templates()
|
||||
template_names = {t["name"] for t in templates}
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate in template_names:
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
def _download_zip(self, repo_name: str) -> bytes:
|
||||
"""Download the default branch zipball for a repo."""
|
||||
url = f"{GITHUB_API_BASE}/repos/{GITHUB_ORG}/{repo_name}/zipball"
|
||||
try:
|
||||
response = httpx.get(url, follow_redirects=True, timeout=60)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
click.secho(f"Failed to download template: {e}", fg="red")
|
||||
raise SystemExit(1) from e
|
||||
|
||||
return response.content
|
||||
|
||||
def _extract_zip(self, zip_bytes: bytes, dest: str) -> None:
|
||||
"""Extract a GitHub zipball into dest, stripping the top-level directory."""
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
# GitHub zipballs have a single top-level dir like 'crewAIInc-template_xxx-<sha>/'
|
||||
members = zf.namelist()
|
||||
if not members:
|
||||
click.secho("Downloaded archive is empty.", fg="red")
|
||||
raise SystemExit(1)
|
||||
|
||||
top_dir = members[0].split("/")[0] + "/"
|
||||
|
||||
os.makedirs(dest, exist_ok=True)
|
||||
|
||||
for member in members:
|
||||
if member == top_dir or not member.startswith(top_dir):
|
||||
continue
|
||||
|
||||
relative_path = member[len(top_dir) :]
|
||||
if not relative_path:
|
||||
continue
|
||||
|
||||
target = os.path.realpath(os.path.join(dest, relative_path))
|
||||
if not target.startswith(
|
||||
os.path.realpath(dest) + os.sep
|
||||
) and target != os.path.realpath(dest):
|
||||
continue
|
||||
|
||||
if member.endswith("/"):
|
||||
os.makedirs(target, exist_ok=True)
|
||||
else:
|
||||
os.makedirs(os.path.dirname(target), exist_ok=True)
|
||||
with zf.open(member) as src, open(target, "wb") as dst:
|
||||
shutil.copyfileobj(src, dst)
|
||||
@@ -5,7 +5,7 @@ import click
|
||||
from packaging import version
|
||||
|
||||
from crewai.cli.utils import build_env_with_all_tool_credentials, read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.utilities.version import get_crewai_version
|
||||
|
||||
|
||||
class CrewType(Enum):
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.14.0"
|
||||
"crewai[tools]==1.14.2rc1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.14.0"
|
||||
"crewai[tools]==1.14.2rc1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.14.0"
|
||||
"crewai[tools]==1.14.2rc1"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime, timedelta
|
||||
from functools import lru_cache
|
||||
import importlib.metadata
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -13,6 +12,8 @@ from urllib.error import URLError
|
||||
import appdirs
|
||||
from packaging.version import InvalidVersion, Version, parse
|
||||
|
||||
from crewai.utilities.version import get_crewai_version
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_cache_file() -> Path:
|
||||
@@ -25,11 +26,6 @@ def _get_cache_file() -> Path:
|
||||
return cache_dir / "version_cache.json"
|
||||
|
||||
|
||||
def get_crewai_version() -> str:
|
||||
"""Get the version number of CrewAI running the CLI."""
|
||||
return importlib.metadata.version("crewai")
|
||||
|
||||
|
||||
def _is_cache_valid(cache_data: Mapping[str, Any]) -> bool:
|
||||
"""Check if the cache is still valid, less than 24 hours old."""
|
||||
if "timestamp" not in cache_data:
|
||||
|
||||
@@ -42,7 +42,6 @@ if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from crewai.context import ExecutionContext
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
|
||||
try:
|
||||
from crewai_files import get_supported_content_types
|
||||
@@ -104,7 +103,11 @@ from crewai.rag.types import SearchResult
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
from crewai.skills.models import Skill
|
||||
from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint
|
||||
from crewai.state.checkpoint_config import (
|
||||
CheckpointConfig,
|
||||
_coerce_checkpoint,
|
||||
apply_checkpoint,
|
||||
)
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
@@ -134,6 +137,7 @@ from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.streaming import (
|
||||
create_async_chunk_generator,
|
||||
create_chunk_generator,
|
||||
register_cleanup,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
@@ -364,28 +368,21 @@ class Crew(FlowTrackable, BaseModel):
|
||||
checkpoint_kickoff_event_id: str | None = Field(default=None)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls, path: str, *, provider: BaseProvider | None = None
|
||||
) -> Crew:
|
||||
"""Restore a Crew from a checkpoint file, ready to resume via kickoff().
|
||||
def from_checkpoint(cls, config: CheckpointConfig) -> Crew:
|
||||
"""Restore a Crew from a checkpoint, ready to resume via kickoff().
|
||||
|
||||
Args:
|
||||
path: Path to a checkpoint JSON file.
|
||||
provider: Storage backend to read from. Defaults to JsonProvider.
|
||||
config: Checkpoint configuration with ``restore_from`` set to
|
||||
the path of the checkpoint to load.
|
||||
|
||||
Returns:
|
||||
A Crew instance. Call kickoff() to resume from the last completed task.
|
||||
"""
|
||||
from crewai.context import apply_execution_context
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
from crewai.state.runtime import RuntimeState
|
||||
|
||||
state = RuntimeState.from_checkpoint(
|
||||
path,
|
||||
provider=provider or JsonProvider(),
|
||||
context={"from_checkpoint": True},
|
||||
)
|
||||
state = RuntimeState.from_checkpoint(config, context={"from_checkpoint": True})
|
||||
crewai_event_bus.set_runtime_state(state)
|
||||
for entity in state.root:
|
||||
if isinstance(entity, cls):
|
||||
@@ -393,7 +390,32 @@ class Crew(FlowTrackable, BaseModel):
|
||||
apply_execution_context(entity.execution_context)
|
||||
entity._restore_runtime()
|
||||
return entity
|
||||
raise ValueError(f"No Crew found in checkpoint: {path}")
|
||||
raise ValueError(f"No Crew found in checkpoint: {config.restore_from}")
|
||||
|
||||
@classmethod
|
||||
def fork(
|
||||
cls,
|
||||
config: CheckpointConfig,
|
||||
branch: str | None = None,
|
||||
) -> Crew:
|
||||
"""Fork a Crew from a checkpoint, creating a new execution branch.
|
||||
|
||||
Args:
|
||||
config: Checkpoint configuration with ``restore_from`` set.
|
||||
branch: Branch label for the fork. Auto-generated if not provided.
|
||||
|
||||
Returns:
|
||||
A Crew instance on the new branch. Call kickoff() to run.
|
||||
"""
|
||||
crew = cls.from_checkpoint(config)
|
||||
state = crewai_event_bus._runtime_state
|
||||
if state is None:
|
||||
raise RuntimeError(
|
||||
"Cannot fork: no runtime state on the event bus. "
|
||||
"Ensure from_checkpoint() succeeded before calling fork()."
|
||||
)
|
||||
state.fork(branch)
|
||||
return crew
|
||||
|
||||
def _restore_runtime(self) -> None:
|
||||
"""Re-create runtime objects after restoring from a checkpoint."""
|
||||
@@ -414,6 +436,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if agent.agent_executor is not None and task.output is None:
|
||||
agent.agent_executor.task = task
|
||||
break
|
||||
for task in self.tasks:
|
||||
if task.checkpoint_original_description is not None:
|
||||
task._original_description = task.checkpoint_original_description
|
||||
if task.checkpoint_original_expected_output is not None:
|
||||
task._original_expected_output = (
|
||||
task.checkpoint_original_expected_output
|
||||
)
|
||||
if self.checkpoint_inputs is not None:
|
||||
self._inputs = self.checkpoint_inputs
|
||||
if self.checkpoint_kickoff_event_id is not None:
|
||||
@@ -849,16 +878,23 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
"""Execute the crew's workflow.
|
||||
|
||||
Args:
|
||||
inputs: Optional input dictionary for task interpolation.
|
||||
input_files: Optional dict of named file inputs for the crew.
|
||||
from_checkpoint: Optional checkpoint config. If ``restore_from``
|
||||
is set, the crew resumes from that checkpoint. Remaining
|
||||
config fields enable checkpointing for the run.
|
||||
|
||||
Returns:
|
||||
CrewOutput or CrewStreamingOutput if streaming is enabled.
|
||||
"""
|
||||
restored = apply_checkpoint(self, from_checkpoint)
|
||||
if restored is not None:
|
||||
return restored.kickoff(inputs=inputs, input_files=input_files) # type: ignore[no-any-return]
|
||||
get_env_context()
|
||||
if self.stream:
|
||||
enable_agent_streaming(self.agents)
|
||||
@@ -882,6 +918,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
register_cleanup(streaming_output, ctx.state)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
return streaming_output
|
||||
|
||||
@@ -970,12 +1007,15 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
"""Asynchronous kickoff method to start the crew execution.
|
||||
|
||||
Args:
|
||||
inputs: Optional input dictionary for task interpolation.
|
||||
input_files: Optional dict of named file inputs for the crew.
|
||||
from_checkpoint: Optional checkpoint config. If ``restore_from``
|
||||
is set, the crew resumes from that checkpoint.
|
||||
|
||||
Returns:
|
||||
CrewOutput or CrewStreamingOutput if streaming is enabled.
|
||||
@@ -984,6 +1024,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
to get stream chunks. After iteration completes, access the final result
|
||||
via .result.
|
||||
"""
|
||||
restored = apply_checkpoint(self, from_checkpoint)
|
||||
if restored is not None:
|
||||
return await restored.kickoff_async(inputs=inputs, input_files=input_files) # type: ignore[no-any-return]
|
||||
inputs = inputs or {}
|
||||
|
||||
if self.stream:
|
||||
@@ -1007,6 +1050,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
register_cleanup(streaming_output, ctx.state)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
@@ -1043,6 +1087,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
"""Native async kickoff method using async task execution throughout.
|
||||
|
||||
@@ -1053,10 +1098,15 @@ class Crew(FlowTrackable, BaseModel):
|
||||
Args:
|
||||
inputs: Optional input dictionary for task interpolation.
|
||||
input_files: Optional dict of named file inputs for the crew.
|
||||
from_checkpoint: Optional checkpoint config. If ``restore_from``
|
||||
is set, the crew resumes from that checkpoint.
|
||||
|
||||
Returns:
|
||||
CrewOutput or CrewStreamingOutput if streaming is enabled.
|
||||
"""
|
||||
restored = apply_checkpoint(self, from_checkpoint)
|
||||
if restored is not None:
|
||||
return await restored.akickoff(inputs=inputs, input_files=input_files) # type: ignore[no-any-return]
|
||||
if self.stream:
|
||||
enable_agent_streaming(self.agents)
|
||||
ctx = StreamingContext(use_async=True)
|
||||
@@ -1078,6 +1128,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
register_cleanup(streaming_output, ctx.state)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
@@ -431,6 +431,7 @@ async def run_for_each_async(
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.streaming import (
|
||||
create_async_chunk_generator,
|
||||
register_cleanup,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
@@ -480,6 +481,7 @@ async def run_for_each_async(
|
||||
streaming_output._set_results(result)
|
||||
|
||||
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
|
||||
register_cleanup(streaming_output, ctx.state)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
@@ -13,13 +13,13 @@ from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
get_user_id,
|
||||
is_tracing_enabled_in_context,
|
||||
should_auto_collect_first_time_traces,
|
||||
)
|
||||
from crewai.utilities.version import get_crewai_version
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
@@ -7,7 +7,6 @@ import uuid
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
@@ -127,6 +126,7 @@ from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.utilities.version import get_crewai_version
|
||||
|
||||
|
||||
class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
@@ -91,7 +91,7 @@ from crewai.utilities.agent_utils import (
|
||||
track_delegation_if_needed,
|
||||
)
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.planning_types import (
|
||||
PlanStep,
|
||||
StepObservation,
|
||||
@@ -189,7 +189,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
)
|
||||
callbacks: list[Any] = Field(default_factory=list, exclude=True)
|
||||
response_model: type[BaseModel] | None = Field(default=None, exclude=True)
|
||||
i18n: I18N | None = Field(default=None, exclude=True)
|
||||
log_error_after: int = Field(default=3, exclude=True)
|
||||
before_llm_call_hooks: list[BeforeLLMCallHookType | BeforeLLMCallHookCallable] = (
|
||||
Field(default_factory=list, exclude=True)
|
||||
@@ -198,7 +197,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
default_factory=list, exclude=True
|
||||
)
|
||||
|
||||
_i18n: I18N = PrivateAttr(default_factory=get_i18n)
|
||||
_console: Console = PrivateAttr(default_factory=Console)
|
||||
_last_parser_error: OutputParserError | None = PrivateAttr(default=None)
|
||||
_last_context_error: Exception | None = PrivateAttr(default=None)
|
||||
@@ -214,7 +212,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
@model_validator(mode="after")
|
||||
def _setup_executor(self) -> Self:
|
||||
"""Configure executor after Pydantic field initialization."""
|
||||
self._i18n = self.i18n or get_i18n()
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||
|
||||
@@ -363,7 +360,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
function_calling_llm=self.function_calling_llm,
|
||||
request_within_rpm_limit=self.request_within_rpm_limit,
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
)
|
||||
return self._step_executor
|
||||
|
||||
@@ -1203,7 +1199,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
formatted_answer = handle_max_iterations_exceeded(
|
||||
formatted_answer=None,
|
||||
printer=PRINTER,
|
||||
i18n=self._i18n,
|
||||
messages=list(self.state.messages),
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
@@ -1430,7 +1425,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
agent_action=action,
|
||||
fingerprint_context=fingerprint_context,
|
||||
tools=self.tools,
|
||||
i18n=self._i18n,
|
||||
agent_key=self.agent.key if self.agent else None,
|
||||
agent_role=self.agent.role if self.agent else None,
|
||||
tools_handler=self.tools_handler,
|
||||
@@ -1450,7 +1444,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
action.result = str(e)
|
||||
self._append_message_to_state(action.text)
|
||||
|
||||
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
|
||||
reasoning_prompt = I18N_DEFAULT.slice("post_tool_reasoning")
|
||||
reasoning_message: LLMMessage = {
|
||||
"role": "user",
|
||||
"content": reasoning_prompt,
|
||||
@@ -1471,7 +1465,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
self.state.is_finished = True
|
||||
return "tool_result_is_final"
|
||||
|
||||
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
|
||||
reasoning_prompt = I18N_DEFAULT.slice("post_tool_reasoning")
|
||||
reasoning_message_post: LLMMessage = {
|
||||
"role": "user",
|
||||
"content": reasoning_prompt,
|
||||
@@ -2222,10 +2216,10 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
# Build synthesis prompt
|
||||
role = self.agent.role if self.agent else "Assistant"
|
||||
|
||||
system_prompt = self._i18n.retrieve(
|
||||
system_prompt = I18N_DEFAULT.retrieve(
|
||||
"planning", "synthesis_system_prompt"
|
||||
).format(role=role)
|
||||
user_prompt = self._i18n.retrieve("planning", "synthesis_user_prompt").format(
|
||||
user_prompt = I18N_DEFAULT.retrieve("planning", "synthesis_user_prompt").format(
|
||||
task_description=task_description,
|
||||
combined_steps=combined_steps,
|
||||
)
|
||||
@@ -2472,7 +2466,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
self.task.description if self.task else getattr(self, "_kickoff_input", "")
|
||||
)
|
||||
|
||||
enhancement = self._i18n.retrieve(
|
||||
enhancement = I18N_DEFAULT.retrieve(
|
||||
"planning", "replan_enhancement_prompt"
|
||||
).format(previous_context=previous_context)
|
||||
|
||||
@@ -2535,7 +2529,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
messages=self.state.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
|
||||
@@ -2746,7 +2739,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignor
|
||||
Returns:
|
||||
Updated action or final answer.
|
||||
"""
|
||||
add_image_tool = self._i18n.tools("add_image")
|
||||
add_image_tool = I18N_DEFAULT.tools("add_image")
|
||||
if (
|
||||
isinstance(add_image_tool, dict)
|
||||
and formatted_answer.tool.casefold().strip()
|
||||
|
||||
@@ -113,7 +113,11 @@ from crewai.flow.utils import (
|
||||
)
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint
|
||||
from crewai.state.checkpoint_config import (
|
||||
CheckpointConfig,
|
||||
_coerce_checkpoint,
|
||||
apply_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -122,7 +126,6 @@ if TYPE_CHECKING:
|
||||
from crewai.context import ExecutionContext
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
|
||||
from crewai.flow.visualization import build_flow_structure, render_interactive
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
@@ -132,6 +135,7 @@ from crewai.utilities.streaming import (
|
||||
create_async_chunk_generator,
|
||||
create_chunk_generator,
|
||||
create_streaming_state,
|
||||
register_cleanup,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
@@ -927,20 +931,21 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
] = Field(default=None)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls, path: str, *, provider: BaseProvider | None = None
|
||||
) -> Flow: # type: ignore[type-arg]
|
||||
"""Restore a Flow from a checkpoint file."""
|
||||
def from_checkpoint(cls, config: CheckpointConfig) -> Flow: # type: ignore[type-arg]
|
||||
"""Restore a Flow from a checkpoint.
|
||||
|
||||
Args:
|
||||
config: Checkpoint configuration with ``restore_from`` set to
|
||||
the path of the checkpoint to load.
|
||||
|
||||
Returns:
|
||||
A Flow instance ready to resume.
|
||||
"""
|
||||
from crewai.context import apply_execution_context
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
from crewai.state.runtime import RuntimeState
|
||||
|
||||
state = RuntimeState.from_checkpoint(
|
||||
path,
|
||||
provider=provider or JsonProvider(),
|
||||
context={"from_checkpoint": True},
|
||||
)
|
||||
state = RuntimeState.from_checkpoint(config, context={"from_checkpoint": True})
|
||||
crewai_event_bus.set_runtime_state(state)
|
||||
for entity in state.root:
|
||||
if not isinstance(entity, Flow):
|
||||
@@ -957,7 +962,32 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
instance.checkpoint_state = entity.checkpoint_state
|
||||
instance._restore_from_checkpoint()
|
||||
return instance
|
||||
raise ValueError(f"No Flow found in checkpoint: {path}")
|
||||
raise ValueError(f"No Flow found in checkpoint: {config.restore_from}")
|
||||
|
||||
@classmethod
|
||||
def fork(
|
||||
cls,
|
||||
config: CheckpointConfig,
|
||||
branch: str | None = None,
|
||||
) -> Flow: # type: ignore[type-arg]
|
||||
"""Fork a Flow from a checkpoint, creating a new execution branch.
|
||||
|
||||
Args:
|
||||
config: Checkpoint configuration with ``restore_from`` set.
|
||||
branch: Branch label for the fork. Auto-generated if not provided.
|
||||
|
||||
Returns:
|
||||
A Flow instance on the new branch. Call kickoff() to run.
|
||||
"""
|
||||
flow = cls.from_checkpoint(config)
|
||||
state = crewai_event_bus._runtime_state
|
||||
if state is None:
|
||||
raise RuntimeError(
|
||||
"Cannot fork: no runtime state on the event bus. "
|
||||
"Ensure from_checkpoint() succeeded before calling fork()."
|
||||
)
|
||||
state.fork(branch)
|
||||
return flow
|
||||
|
||||
checkpoint_completed_methods: set[str] | None = Field(default=None)
|
||||
checkpoint_method_outputs: list[Any] | None = Field(default=None)
|
||||
@@ -1454,6 +1484,25 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
"No pending feedback context. Use from_pending() to restore a paused flow."
|
||||
)
|
||||
|
||||
if get_current_parent_id() is None:
|
||||
reset_emission_counter()
|
||||
reset_last_event_id()
|
||||
|
||||
if not self.suppress_flow_events:
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
FlowStartedEvent(
|
||||
type="flow_started",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
inputs=None,
|
||||
),
|
||||
)
|
||||
if future and isinstance(future, Future):
|
||||
try:
|
||||
await asyncio.wrap_future(future)
|
||||
except Exception:
|
||||
logger.warning("FlowStartedEvent handler failed", exc_info=True)
|
||||
|
||||
context = self._pending_feedback_context
|
||||
emit = context.emit
|
||||
default_outcome = context.default_outcome
|
||||
@@ -1593,16 +1642,39 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
final_result = self._method_outputs[-1] if self._method_outputs else result
|
||||
|
||||
# Emit flow finished
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
FlowFinishedEvent(
|
||||
type="flow_finished",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
result=final_result,
|
||||
state=self._state,
|
||||
),
|
||||
)
|
||||
if self._event_futures:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
asyncio.wrap_future(f)
|
||||
for f in self._event_futures
|
||||
if isinstance(f, Future)
|
||||
]
|
||||
)
|
||||
self._event_futures.clear()
|
||||
|
||||
if not self.suppress_flow_events:
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
FlowFinishedEvent(
|
||||
type="flow_finished",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
result=final_result,
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
)
|
||||
if future and isinstance(future, Future):
|
||||
try:
|
||||
await asyncio.wrap_future(future)
|
||||
except Exception:
|
||||
logger.warning("FlowFinishedEvent handler failed", exc_info=True)
|
||||
|
||||
trace_listener = TraceCollectionListener()
|
||||
if trace_listener.batch_manager.batch_owner_type == "flow":
|
||||
if trace_listener.first_time_handler.is_first_time:
|
||||
trace_listener.first_time_handler.mark_events_collected()
|
||||
trace_listener.first_time_handler.handle_execution_completion()
|
||||
else:
|
||||
trace_listener.batch_manager.finalize_batch()
|
||||
|
||||
return final_result
|
||||
|
||||
@@ -1913,6 +1985,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
) -> Any | FlowStreamingOutput:
|
||||
"""Start the flow execution in a synchronous context.
|
||||
|
||||
@@ -1922,10 +1995,15 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
Args:
|
||||
inputs: Optional dictionary containing input values and/or a state ID.
|
||||
input_files: Optional dict of named file inputs for the flow.
|
||||
from_checkpoint: Optional checkpoint config. If ``restore_from``
|
||||
is set, the flow resumes from that checkpoint.
|
||||
|
||||
Returns:
|
||||
The final output from the flow or FlowStreamingOutput if streaming.
|
||||
"""
|
||||
restored = apply_checkpoint(self, from_checkpoint)
|
||||
if restored is not None:
|
||||
return restored.kickoff(inputs=inputs, input_files=input_files)
|
||||
get_env_context()
|
||||
if self.stream:
|
||||
result_holder: list[Any] = []
|
||||
@@ -1962,6 +2040,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
streaming_output = FlowStreamingOutput(
|
||||
sync_iterator=create_chunk_generator(state, run_flow, output_holder)
|
||||
)
|
||||
register_cleanup(streaming_output, state)
|
||||
output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
@@ -1981,6 +2060,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
) -> Any | FlowStreamingOutput:
|
||||
"""Start the flow execution asynchronously.
|
||||
|
||||
@@ -1992,10 +2072,15 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
Args:
|
||||
inputs: Optional dictionary containing input values and/or a state ID for restoration.
|
||||
input_files: Optional dict of named file inputs for the flow.
|
||||
from_checkpoint: Optional checkpoint config. If ``restore_from``
|
||||
is set, the flow resumes from that checkpoint.
|
||||
|
||||
Returns:
|
||||
The final output from the flow, which is the result of the last executed method.
|
||||
"""
|
||||
restored = apply_checkpoint(self, from_checkpoint)
|
||||
if restored is not None:
|
||||
return await restored.kickoff_async(inputs=inputs, input_files=input_files)
|
||||
if self.stream:
|
||||
result_holder: list[Any] = []
|
||||
current_task_info: TaskInfo = {
|
||||
@@ -2035,6 +2120,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
state, run_flow, output_holder
|
||||
)
|
||||
)
|
||||
register_cleanup(streaming_output, state)
|
||||
output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
@@ -2052,7 +2138,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
try:
|
||||
# Reset flow state for fresh execution unless restoring from persistence
|
||||
is_restoring = inputs and "id" in inputs and self.persistence is not None
|
||||
is_restoring = (
|
||||
inputs and "id" in inputs and self.persistence is not None
|
||||
) or self.checkpoint_completed_methods is not None
|
||||
if not is_restoring:
|
||||
# Clear completed methods and outputs for a fresh start
|
||||
self._completed_methods.clear()
|
||||
@@ -2253,17 +2341,20 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
) -> Any | FlowStreamingOutput:
|
||||
"""Native async method to start the flow execution. Alias for kickoff_async.
|
||||
|
||||
Args:
|
||||
inputs: Optional dictionary containing input values and/or a state ID for restoration.
|
||||
input_files: Optional dict of named file inputs for the flow.
|
||||
from_checkpoint: Optional checkpoint config. If ``restore_from``
|
||||
is set, the flow resumes from that checkpoint.
|
||||
|
||||
Returns:
|
||||
The final output from the flow, which is the result of the last executed method.
|
||||
"""
|
||||
return await self.kickoff_async(inputs, input_files)
|
||||
return await self.kickoff_async(inputs, input_files, from_checkpoint)
|
||||
|
||||
async def _execute_start_method(self, start_method_name: FlowMethodName) -> None:
|
||||
"""Executes a flow's start method and its triggered listeners.
|
||||
@@ -3191,7 +3282,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM as BaseLLMClass
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
llm_instance: BaseLLMClass
|
||||
if isinstance(llm, str):
|
||||
@@ -3211,9 +3302,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
description=f"The outcome that best matches the feedback. Must be one of: {', '.join(outcomes)}"
|
||||
)
|
||||
|
||||
# Load prompt from translations (using cached instance)
|
||||
i18n = get_i18n()
|
||||
prompt_template = i18n.slice("human_feedback_collapse")
|
||||
prompt_template = I18N_DEFAULT.slice("human_feedback_collapse")
|
||||
|
||||
prompt = prompt_template.format(
|
||||
feedback=feedback,
|
||||
|
||||
@@ -350,9 +350,9 @@ def human_feedback(
|
||||
|
||||
def _get_hitl_prompt(key: str) -> str:
|
||||
"""Read a HITL prompt from the i18n translations."""
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
return get_i18n().slice(key)
|
||||
return I18N_DEFAULT.slice(key)
|
||||
|
||||
def _resolve_llm_instance() -> Any:
|
||||
"""Resolve the ``llm`` parameter to a BaseLLM instance.
|
||||
|
||||
@@ -5,6 +5,8 @@ from functools import wraps
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
@@ -37,6 +39,9 @@ def _create_hook_decorator(
|
||||
tools: list[str] | None = None,
|
||||
agents: list[str] | None = None,
|
||||
) -> Callable[..., Any]:
|
||||
if tools:
|
||||
tools = [sanitize_tool_name(t) for t in tools]
|
||||
|
||||
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
setattr(f, marker_attribute, True)
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from typing import (
|
||||
get_origin,
|
||||
)
|
||||
import uuid
|
||||
import warnings
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
@@ -26,7 +25,7 @@ from pydantic import (
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
from typing_extensions import Self, deprecated
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -89,7 +88,7 @@ from crewai.utilities.converter import (
|
||||
)
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.guardrail_types import GuardrailCallable, GuardrailType
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.printer import PRINTER
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
@@ -173,9 +172,12 @@ def _kickoff_with_a2a_support(
|
||||
)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"LiteAgent is deprecated and will be removed in v2.0.0.",
|
||||
category=FutureWarning,
|
||||
)
|
||||
class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""
|
||||
A lightweight agent that can process messages and use tools.
|
||||
"""A lightweight agent that can process messages and use tools.
|
||||
|
||||
.. deprecated::
|
||||
LiteAgent is deprecated and will be removed in a future version.
|
||||
@@ -227,9 +229,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Callback to check if the request is within the RPM8 limit",
|
||||
)
|
||||
i18n: I18N = Field(
|
||||
default_factory=get_i18n, description="Internationalization settings."
|
||||
)
|
||||
response_format: type[BaseModel] | None = Field(
|
||||
default=None, description="Pydantic model for structured output"
|
||||
)
|
||||
@@ -281,18 +280,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
_memory: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def emit_deprecation_warning(self) -> Self:
|
||||
"""Emit deprecation warning for LiteAgent usage."""
|
||||
warnings.warn(
|
||||
"LiteAgent is deprecated and will be removed in a future version. "
|
||||
"Use Agent().kickoff(messages) instead, which provides the same "
|
||||
"functionality with additional features like memory and knowledge support.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_llm(self) -> Self:
|
||||
"""Set up the LLM and other components after initialization."""
|
||||
@@ -571,7 +558,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
f"- {m.record.content}" for m in matches
|
||||
)
|
||||
if memory_block:
|
||||
formatted = self.i18n.slice("memory").format(memory=memory_block)
|
||||
formatted = I18N_DEFAULT.slice("memory").format(memory=memory_block)
|
||||
if self._messages and self._messages[0].get("role") == "system":
|
||||
existing_content = self._messages[0].get("content", "")
|
||||
if not isinstance(existing_content, str):
|
||||
@@ -644,7 +631,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
try:
|
||||
model_schema = generate_model_description(active_response_format)
|
||||
schema = json.dumps(model_schema, indent=2)
|
||||
instructions = self.i18n.slice("formatted_task_instructions").format(
|
||||
instructions = I18N_DEFAULT.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
|
||||
@@ -793,7 +780,9 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
base_prompt = ""
|
||||
if self._parsed_tools:
|
||||
# Use the prompt template for agents with tools
|
||||
base_prompt = self.i18n.slice("lite_agent_system_prompt_with_tools").format(
|
||||
base_prompt = I18N_DEFAULT.slice(
|
||||
"lite_agent_system_prompt_with_tools"
|
||||
).format(
|
||||
role=self.role,
|
||||
backstory=self.backstory,
|
||||
goal=self.goal,
|
||||
@@ -802,7 +791,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
else:
|
||||
# Use the prompt template for agents without tools
|
||||
base_prompt = self.i18n.slice(
|
||||
base_prompt = I18N_DEFAULT.slice(
|
||||
"lite_agent_system_prompt_without_tools"
|
||||
).format(
|
||||
role=self.role,
|
||||
@@ -814,7 +803,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
if active_response_format:
|
||||
model_description = generate_model_description(active_response_format)
|
||||
schema_json = json.dumps(model_description, indent=2)
|
||||
base_prompt += self.i18n.slice("lite_agent_response_format").format(
|
||||
base_prompt += I18N_DEFAULT.slice("lite_agent_response_format").format(
|
||||
response_format=schema_json
|
||||
)
|
||||
|
||||
@@ -875,7 +864,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
formatted_answer = handle_max_iterations_exceeded(
|
||||
formatted_answer,
|
||||
printer=PRINTER,
|
||||
i18n=self.i18n,
|
||||
messages=self._messages,
|
||||
llm=cast(LLM, self.llm),
|
||||
callbacks=self._callbacks,
|
||||
@@ -914,7 +902,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
tool_result = execute_tool_and_check_finality(
|
||||
agent_action=formatted_answer,
|
||||
tools=self._parsed_tools,
|
||||
i18n=self.i18n,
|
||||
agent_key=self.key,
|
||||
agent_role=self.role,
|
||||
agent=self.original_agent,
|
||||
@@ -956,7 +943,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
messages=self._messages,
|
||||
llm=cast(LLM, self.llm),
|
||||
callbacks=self._callbacks,
|
||||
i18n=self.i18n,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -51,6 +51,7 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
)
|
||||
from crewai.utilities.logger_utils import suppress_warnings
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
|
||||
|
||||
try:
|
||||
@@ -75,8 +76,13 @@ try:
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionDeltaToolCall,
|
||||
Choices,
|
||||
Delta as LiteLLMDelta,
|
||||
Function,
|
||||
Message,
|
||||
ModelResponse,
|
||||
ModelResponseBase,
|
||||
ModelResponseStream,
|
||||
StreamingChoices as LiteLLMStreamingChoices,
|
||||
)
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
@@ -85,6 +91,11 @@ except ImportError:
|
||||
LITELLM_AVAILABLE = False
|
||||
litellm = None # type: ignore[assignment]
|
||||
Choices = None # type: ignore[assignment, misc]
|
||||
LiteLLMDelta = None # type: ignore[assignment, misc]
|
||||
Message = None # type: ignore[assignment, misc]
|
||||
ModelResponseBase = None # type: ignore[assignment, misc]
|
||||
ModelResponseStream = None # type: ignore[assignment, misc]
|
||||
LiteLLMStreamingChoices = None # type: ignore[assignment, misc]
|
||||
get_supported_openai_params = None # type: ignore[assignment]
|
||||
ChatCompletionDeltaToolCall = None # type: ignore[assignment, misc]
|
||||
Function = None # type: ignore[assignment, misc]
|
||||
@@ -709,7 +720,7 @@ class LLM(BaseLLM):
|
||||
chunk_content = None
|
||||
response_id = None
|
||||
|
||||
if hasattr(chunk, "id"):
|
||||
if isinstance(chunk, ModelResponseBase):
|
||||
response_id = chunk.id
|
||||
|
||||
# Safely extract content from various chunk formats
|
||||
@@ -718,18 +729,16 @@ class LLM(BaseLLM):
|
||||
choices = None
|
||||
if isinstance(chunk, dict) and "choices" in chunk:
|
||||
choices = chunk["choices"]
|
||||
elif hasattr(chunk, "choices"):
|
||||
# Check if choices is not a type but an actual attribute with value
|
||||
if not isinstance(chunk.choices, type):
|
||||
choices = chunk.choices
|
||||
elif isinstance(chunk, ModelResponseStream):
|
||||
choices = chunk.choices
|
||||
|
||||
# Try to extract usage information if available
|
||||
# NOTE: usage is a pydantic extra field on ModelResponseBase,
|
||||
# so it must be accessed via model_extra.
|
||||
if isinstance(chunk, dict) and "usage" in chunk:
|
||||
usage_info = chunk["usage"]
|
||||
elif hasattr(chunk, "usage"):
|
||||
# Check if usage is not a type but an actual attribute with value
|
||||
if not isinstance(chunk.usage, type):
|
||||
usage_info = chunk.usage
|
||||
elif isinstance(chunk, ModelResponseBase) and chunk.model_extra:
|
||||
usage_info = chunk.model_extra.get("usage") or usage_info
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
@@ -738,7 +747,7 @@ class LLM(BaseLLM):
|
||||
delta = None
|
||||
if isinstance(choice, dict) and "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
elif hasattr(choice, "delta"):
|
||||
elif isinstance(choice, LiteLLMStreamingChoices):
|
||||
delta = choice.delta
|
||||
|
||||
# Extract content from delta
|
||||
@@ -748,7 +757,7 @@ class LLM(BaseLLM):
|
||||
if "content" in delta and delta["content"] is not None:
|
||||
chunk_content = delta["content"]
|
||||
# Handle object format
|
||||
elif hasattr(delta, "content"):
|
||||
elif isinstance(delta, LiteLLMDelta):
|
||||
chunk_content = delta.content
|
||||
|
||||
# Handle case where content might be None or empty
|
||||
@@ -821,9 +830,8 @@ class LLM(BaseLLM):
|
||||
choices = None
|
||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||
choices = last_chunk["choices"]
|
||||
elif hasattr(last_chunk, "choices"):
|
||||
if not isinstance(last_chunk.choices, type):
|
||||
choices = last_chunk.choices
|
||||
elif isinstance(last_chunk, ModelResponseStream):
|
||||
choices = last_chunk.choices
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
@@ -832,14 +840,14 @@ class LLM(BaseLLM):
|
||||
message = None
|
||||
if isinstance(choice, dict) and "message" in choice:
|
||||
message = choice["message"]
|
||||
elif hasattr(choice, "message"):
|
||||
elif isinstance(choice, Choices):
|
||||
message = choice.message
|
||||
|
||||
if message:
|
||||
content = None
|
||||
if isinstance(message, dict) and "content" in message:
|
||||
content = message["content"]
|
||||
elif hasattr(message, "content"):
|
||||
elif isinstance(message, Message):
|
||||
content = message.content
|
||||
|
||||
if content:
|
||||
@@ -866,24 +874,23 @@ class LLM(BaseLLM):
|
||||
choices = None
|
||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||
choices = last_chunk["choices"]
|
||||
elif hasattr(last_chunk, "choices"):
|
||||
if not isinstance(last_chunk.choices, type):
|
||||
choices = last_chunk.choices
|
||||
elif isinstance(last_chunk, ModelResponseStream):
|
||||
choices = last_chunk.choices
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
|
||||
message = None
|
||||
if isinstance(choice, dict) and "message" in choice:
|
||||
message = choice["message"]
|
||||
elif hasattr(choice, "message"):
|
||||
message = choice.message
|
||||
delta = None
|
||||
if isinstance(choice, dict) and "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
elif isinstance(choice, LiteLLMStreamingChoices):
|
||||
delta = choice.delta
|
||||
|
||||
if message:
|
||||
if isinstance(message, dict) and "tool_calls" in message:
|
||||
tool_calls = message["tool_calls"]
|
||||
elif hasattr(message, "tool_calls"):
|
||||
tool_calls = message.tool_calls
|
||||
if delta:
|
||||
if isinstance(delta, dict) and "tool_calls" in delta:
|
||||
tool_calls = delta["tool_calls"]
|
||||
elif isinstance(delta, LiteLLMDelta):
|
||||
tool_calls = delta.tool_calls
|
||||
except Exception as e:
|
||||
logging.debug(f"Error checking for tool calls: {e}")
|
||||
|
||||
@@ -1037,7 +1044,7 @@ class LLM(BaseLLM):
|
||||
"""
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
if isinstance(callback, TokenCalcHandler):
|
||||
# Use the usage_info we've been tracking
|
||||
if not usage_info:
|
||||
# Try to get usage from the last chunk if we haven't already
|
||||
@@ -1048,9 +1055,14 @@ class LLM(BaseLLM):
|
||||
and "usage" in last_chunk
|
||||
):
|
||||
usage_info = last_chunk["usage"]
|
||||
elif hasattr(last_chunk, "usage"):
|
||||
if not isinstance(last_chunk.usage, type):
|
||||
usage_info = last_chunk.usage
|
||||
elif (
|
||||
isinstance(last_chunk, ModelResponseBase)
|
||||
and last_chunk.model_extra
|
||||
):
|
||||
usage_info = (
|
||||
last_chunk.model_extra.get("usage")
|
||||
or usage_info
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug(f"Error extracting usage info: {e}")
|
||||
|
||||
@@ -1123,13 +1135,10 @@ class LLM(BaseLLM):
|
||||
params["response_model"] = response_model
|
||||
response = litellm.completion(**params)
|
||||
|
||||
if (
|
||||
hasattr(response, "usage")
|
||||
and not isinstance(response.usage, type)
|
||||
and response.usage
|
||||
):
|
||||
usage_info = response.usage
|
||||
self._track_token_usage_internal(usage_info)
|
||||
if isinstance(response, ModelResponseBase) and response.model_extra:
|
||||
usage_info = response.model_extra.get("usage")
|
||||
if usage_info:
|
||||
self._track_token_usage_internal(usage_info)
|
||||
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise our own context length error
|
||||
@@ -1141,7 +1150,11 @@ class LLM(BaseLLM):
|
||||
raise LLMContextLengthExceededError(error_msg) from e
|
||||
raise
|
||||
|
||||
response_usage = self._usage_to_dict(getattr(response, "usage", None))
|
||||
response_usage = self._usage_to_dict(
|
||||
response.model_extra.get("usage")
|
||||
if isinstance(response, ModelResponseBase) and response.model_extra
|
||||
else None
|
||||
)
|
||||
|
||||
# --- 2) Handle structured output response (when response_model is provided)
|
||||
if response_model is not None:
|
||||
@@ -1166,8 +1179,13 @@ class LLM(BaseLLM):
|
||||
# --- 3) Handle callbacks with usage info
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
usage_info = getattr(response, "usage", None)
|
||||
if isinstance(callback, TokenCalcHandler):
|
||||
usage_info = (
|
||||
response.model_extra.get("usage")
|
||||
if isinstance(response, ModelResponseBase)
|
||||
and response.model_extra
|
||||
else None
|
||||
)
|
||||
if usage_info:
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
@@ -1176,7 +1194,7 @@ class LLM(BaseLLM):
|
||||
end_time=0,
|
||||
)
|
||||
# --- 4) Check for tool calls
|
||||
tool_calls = getattr(response_message, "tool_calls", [])
|
||||
tool_calls = response_message.tool_calls or []
|
||||
|
||||
# --- 5) If no tool calls or no available functions, return the text response directly as long as there is a text response
|
||||
if (not tool_calls or not available_functions) and text_response:
|
||||
@@ -1269,13 +1287,10 @@ class LLM(BaseLLM):
|
||||
params["response_model"] = response_model
|
||||
response = await litellm.acompletion(**params)
|
||||
|
||||
if (
|
||||
hasattr(response, "usage")
|
||||
and not isinstance(response.usage, type)
|
||||
and response.usage
|
||||
):
|
||||
usage_info = response.usage
|
||||
self._track_token_usage_internal(usage_info)
|
||||
if isinstance(response, ModelResponseBase) and response.model_extra:
|
||||
usage_info = response.model_extra.get("usage")
|
||||
if usage_info:
|
||||
self._track_token_usage_internal(usage_info)
|
||||
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise our own context length error
|
||||
@@ -1287,7 +1302,11 @@ class LLM(BaseLLM):
|
||||
raise LLMContextLengthExceededError(error_msg) from e
|
||||
raise
|
||||
|
||||
response_usage = self._usage_to_dict(getattr(response, "usage", None))
|
||||
response_usage = self._usage_to_dict(
|
||||
response.model_extra.get("usage")
|
||||
if isinstance(response, ModelResponseBase) and response.model_extra
|
||||
else None
|
||||
)
|
||||
|
||||
if response_model is not None:
|
||||
if isinstance(response, BaseModel):
|
||||
@@ -1309,8 +1328,13 @@ class LLM(BaseLLM):
|
||||
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
usage_info = getattr(response, "usage", None)
|
||||
if isinstance(callback, TokenCalcHandler):
|
||||
usage_info = (
|
||||
response.model_extra.get("usage")
|
||||
if isinstance(response, ModelResponseBase)
|
||||
and response.model_extra
|
||||
else None
|
||||
)
|
||||
if usage_info:
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
@@ -1319,7 +1343,7 @@ class LLM(BaseLLM):
|
||||
end_time=0,
|
||||
)
|
||||
|
||||
tool_calls = getattr(response_message, "tool_calls", [])
|
||||
tool_calls = response_message.tool_calls or []
|
||||
|
||||
if (not tool_calls or not available_functions) and text_response:
|
||||
self._handle_emit_call_events(
|
||||
@@ -1394,18 +1418,19 @@ class LLM(BaseLLM):
|
||||
async for chunk in await litellm.acompletion(**params):
|
||||
chunk_count += 1
|
||||
chunk_content = None
|
||||
response_id = chunk.id if hasattr(chunk, "id") else None
|
||||
response_id = chunk.id if isinstance(chunk, ModelResponseBase) else None
|
||||
|
||||
try:
|
||||
choices = None
|
||||
if isinstance(chunk, dict) and "choices" in chunk:
|
||||
choices = chunk["choices"]
|
||||
elif hasattr(chunk, "choices"):
|
||||
if not isinstance(chunk.choices, type):
|
||||
choices = chunk.choices
|
||||
elif isinstance(chunk, ModelResponseStream):
|
||||
choices = chunk.choices
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||
usage_info = chunk.usage
|
||||
if isinstance(chunk, ModelResponseBase) and chunk.model_extra:
|
||||
chunk_usage = chunk.model_extra.get("usage")
|
||||
if chunk_usage is not None:
|
||||
usage_info = chunk_usage
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
first_choice = choices[0]
|
||||
@@ -1413,19 +1438,19 @@ class LLM(BaseLLM):
|
||||
|
||||
if isinstance(first_choice, dict):
|
||||
delta = first_choice.get("delta", {})
|
||||
elif hasattr(first_choice, "delta"):
|
||||
elif isinstance(first_choice, LiteLLMStreamingChoices):
|
||||
delta = first_choice.delta
|
||||
|
||||
if delta:
|
||||
if isinstance(delta, dict):
|
||||
chunk_content = delta.get("content")
|
||||
elif hasattr(delta, "content"):
|
||||
elif isinstance(delta, LiteLLMDelta):
|
||||
chunk_content = delta.content
|
||||
|
||||
tool_calls: list[ChatCompletionDeltaToolCall] | None = None
|
||||
if isinstance(delta, dict):
|
||||
tool_calls = delta.get("tool_calls")
|
||||
elif hasattr(delta, "tool_calls"):
|
||||
elif isinstance(delta, LiteLLMDelta):
|
||||
tool_calls = delta.tool_calls
|
||||
|
||||
if tool_calls:
|
||||
@@ -1461,7 +1486,7 @@ class LLM(BaseLLM):
|
||||
|
||||
if callbacks and len(callbacks) > 0 and usage_info:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
if isinstance(callback, TokenCalcHandler):
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
response_obj={"usage": usage_info},
|
||||
@@ -1920,7 +1945,7 @@ class LLM(BaseLLM):
|
||||
return None
|
||||
if isinstance(usage, dict):
|
||||
return usage
|
||||
if hasattr(usage, "model_dump"):
|
||||
if isinstance(usage, BaseModel):
|
||||
result: dict[str, Any] = usage.model_dump()
|
||||
return result
|
||||
if hasattr(usage, "__dict__"):
|
||||
@@ -1984,7 +2009,7 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return messages
|
||||
|
||||
provider = getattr(self, "provider", None) or self.model
|
||||
provider = self.provider or self.model
|
||||
|
||||
for msg in messages:
|
||||
files = msg.get("files")
|
||||
@@ -2035,7 +2060,7 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return messages
|
||||
|
||||
provider = getattr(self, "provider", None) or self.model
|
||||
provider = self.provider or self.model
|
||||
|
||||
for msg in messages:
|
||||
files = msg.get("files")
|
||||
|
||||
@@ -172,6 +172,8 @@ class BaseLLM(BaseModel, ABC):
|
||||
"completion_tokens": 0,
|
||||
"successful_requests": 0,
|
||||
"cached_prompt_tokens": 0,
|
||||
"reasoning_tokens": 0,
|
||||
"cache_creation_tokens": 0,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -808,14 +810,24 @@ class BaseLLM(BaseModel, ABC):
|
||||
cached_tokens = (
|
||||
usage_data.get("cached_tokens")
|
||||
or usage_data.get("cached_prompt_tokens")
|
||||
or usage_data.get("cache_read_input_tokens")
|
||||
or 0
|
||||
)
|
||||
if not cached_tokens:
|
||||
prompt_details = usage_data.get("prompt_tokens_details")
|
||||
if isinstance(prompt_details, dict):
|
||||
cached_tokens = prompt_details.get("cached_tokens", 0) or 0
|
||||
|
||||
reasoning_tokens = usage_data.get("reasoning_tokens", 0) or 0
|
||||
cache_creation_tokens = usage_data.get("cache_creation_tokens", 0) or 0
|
||||
|
||||
self._token_usage["prompt_tokens"] += prompt_tokens
|
||||
self._token_usage["completion_tokens"] += completion_tokens
|
||||
self._token_usage["total_tokens"] += prompt_tokens + completion_tokens
|
||||
self._token_usage["successful_requests"] += 1
|
||||
self._token_usage["cached_prompt_tokens"] += cached_tokens
|
||||
self._token_usage["reasoning_tokens"] += reasoning_tokens
|
||||
self._token_usage["cache_creation_tokens"] += cache_creation_tokens
|
||||
|
||||
def get_token_usage_summary(self) -> UsageMetrics:
|
||||
"""Get summary of token usage for this LLM instance.
|
||||
|
||||
@@ -11,10 +11,14 @@ from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
sanitize_tool_params_for_anthropic_strict,
|
||||
)
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
@@ -189,16 +193,41 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> AnthropicCompletion:
|
||||
self._client = Anthropic(**self._get_client_params())
|
||||
"""Eagerly build clients when the API key is available, otherwise
|
||||
defer so ``LLM(model="anthropic/...")`` can be constructed at module
|
||||
import time even before deployment env vars are set.
|
||||
"""
|
||||
try:
|
||||
self._client = self._build_sync_client()
|
||||
self._async_client = self._build_async_client()
|
||||
except ValueError:
|
||||
pass
|
||||
return self
|
||||
|
||||
async_client_params = self._get_client_params()
|
||||
def _build_sync_client(self) -> Any:
|
||||
return Anthropic(**self._get_client_params())
|
||||
|
||||
def _build_async_client(self) -> Any:
|
||||
# Skip the sync httpx.Client that `_get_client_params` would
|
||||
# otherwise construct under `interceptor`; we attach an async one
|
||||
# below and would leak the sync one if both were built.
|
||||
async_client_params = self._get_client_params(include_http_client=False)
|
||||
if self.interceptor:
|
||||
async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
|
||||
async_http_client = httpx.AsyncClient(transport=async_transport)
|
||||
async_client_params["http_client"] = async_http_client
|
||||
async_client_params["http_client"] = httpx.AsyncClient(
|
||||
transport=async_transport
|
||||
)
|
||||
return AsyncAnthropic(**async_client_params)
|
||||
|
||||
self._async_client = AsyncAnthropic(**async_client_params)
|
||||
return self
|
||||
def _get_sync_client(self) -> Any:
|
||||
if self._client is None:
|
||||
self._client = self._build_sync_client()
|
||||
return self._client
|
||||
|
||||
def _get_async_client(self) -> Any:
|
||||
if self._async_client is None:
|
||||
self._async_client = self._build_async_client()
|
||||
return self._async_client
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Anthropic-specific fields."""
|
||||
@@ -213,8 +242,15 @@ class AnthropicCompletion(BaseLLM):
|
||||
config["timeout"] = self.timeout
|
||||
return config
|
||||
|
||||
def _get_client_params(self) -> dict[str, Any]:
|
||||
"""Get client parameters."""
|
||||
def _get_client_params(self, include_http_client: bool = True) -> dict[str, Any]:
|
||||
"""Get client parameters.
|
||||
|
||||
Args:
|
||||
include_http_client: When True (default) and an interceptor is
|
||||
set, attach a sync ``httpx.Client``. The async builder
|
||||
passes ``False`` so it can attach its own async client
|
||||
without leaking a sync one.
|
||||
"""
|
||||
|
||||
if self.api_key is None:
|
||||
self.api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
@@ -228,7 +264,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
|
||||
if self.interceptor:
|
||||
if include_http_client and self.interceptor:
|
||||
transport = HTTPTransport(interceptor=self.interceptor)
|
||||
http_client = httpx.Client(transport=transport)
|
||||
client_params["http_client"] = http_client # type: ignore[assignment]
|
||||
@@ -473,10 +509,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
continue
|
||||
|
||||
try:
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
name, description, parameters = safe_tool_conversion(tool, "Anthropic")
|
||||
except (ImportError, KeyError, ValueError) as e:
|
||||
except (KeyError, ValueError) as e:
|
||||
logging.error(f"Error converting tool to Anthropic format: {e}")
|
||||
raise e
|
||||
|
||||
@@ -485,8 +519,15 @@ class AnthropicCompletion(BaseLLM):
|
||||
"description": description,
|
||||
}
|
||||
|
||||
func_info = tool.get("function", {})
|
||||
strict_enabled = bool(func_info.get("strict"))
|
||||
|
||||
if parameters and isinstance(parameters, dict):
|
||||
anthropic_tool["input_schema"] = parameters
|
||||
anthropic_tool["input_schema"] = (
|
||||
sanitize_tool_params_for_anthropic_strict(parameters)
|
||||
if strict_enabled
|
||||
else parameters
|
||||
)
|
||||
else:
|
||||
anthropic_tool["input_schema"] = {
|
||||
"type": "object",
|
||||
@@ -494,6 +535,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
"required": [],
|
||||
}
|
||||
|
||||
if strict_enabled:
|
||||
anthropic_tool["strict"] = True
|
||||
|
||||
anthropic_tools.append(anthropic_tool)
|
||||
|
||||
return anthropic_tools
|
||||
@@ -786,11 +830,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
try:
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = self._client.beta.messages.create(
|
||||
response = self._get_sync_client().beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = self._client.messages.create(**params)
|
||||
response = self._get_sync_client().messages.create(**params)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
@@ -938,9 +982,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
stream_context = (
|
||||
self._client.beta.messages.stream(**stream_params, extra_body=extra_body)
|
||||
self._get_sync_client().beta.messages.stream(
|
||||
**stream_params, extra_body=extra_body
|
||||
)
|
||||
if betas
|
||||
else self._client.messages.stream(**stream_params)
|
||||
else self._get_sync_client().messages.stream(**stream_params)
|
||||
)
|
||||
with stream_context as stream:
|
||||
response_id = None
|
||||
@@ -1219,7 +1265,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
try:
|
||||
# Send tool results back to Claude for final response
|
||||
final_response: Message = self._client.messages.create(**follow_up_params)
|
||||
final_response: Message = self._get_sync_client().messages.create(
|
||||
**follow_up_params
|
||||
)
|
||||
|
||||
# Track token usage for follow-up call
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
@@ -1315,11 +1363,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
try:
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = await self._async_client.beta.messages.create(
|
||||
response = await self._get_async_client().beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = await self._async_client.messages.create(**params)
|
||||
response = await self._get_async_client().messages.create(**params)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
@@ -1453,11 +1501,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
stream_context = (
|
||||
self._async_client.beta.messages.stream(
|
||||
self._get_async_client().beta.messages.stream(
|
||||
**stream_params, extra_body=extra_body
|
||||
)
|
||||
if betas
|
||||
else self._async_client.messages.stream(**stream_params)
|
||||
else self._get_async_client().messages.stream(**stream_params)
|
||||
)
|
||||
async with stream_context as stream:
|
||||
response_id = None
|
||||
@@ -1622,7 +1670,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
try:
|
||||
final_response: Message = await self._async_client.messages.create(
|
||||
final_response: Message = await self._get_async_client().messages.create(
|
||||
**follow_up_params
|
||||
)
|
||||
|
||||
@@ -1704,18 +1752,23 @@ class AnthropicCompletion(BaseLLM):
|
||||
def _extract_anthropic_token_usage(
|
||||
response: Message | BetaMessage,
|
||||
) -> dict[str, Any]:
|
||||
"""Extract token usage from Anthropic response."""
|
||||
"""Extract token usage and response metadata from Anthropic response."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = response.usage
|
||||
input_tokens = getattr(usage, "input_tokens", 0)
|
||||
output_tokens = getattr(usage, "output_tokens", 0)
|
||||
cache_read_tokens = getattr(usage, "cache_read_input_tokens", 0) or 0
|
||||
return {
|
||||
cache_creation_tokens = (
|
||||
getattr(usage, "cache_creation_input_tokens", 0) or 0
|
||||
)
|
||||
result: dict[str, Any] = {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": input_tokens + output_tokens,
|
||||
"cached_prompt_tokens": cache_read_tokens,
|
||||
"cache_creation_tokens": cache_creation_tokens,
|
||||
}
|
||||
return result
|
||||
return {"total_tokens": 0}
|
||||
|
||||
def supports_multimodal(self) -> bool:
|
||||
@@ -1745,8 +1798,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
from crewai_files.uploaders.anthropic import AnthropicFileUploader
|
||||
|
||||
return AnthropicFileUploader(
|
||||
client=self._client,
|
||||
async_client=self._async_client,
|
||||
client=self._get_sync_client(),
|
||||
async_client=self._get_async_client(),
|
||||
)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -116,43 +116,100 @@ class AzureCompletion(BaseLLM):
|
||||
data.get("api_version") or os.getenv("AZURE_API_VERSION") or "2024-06-01"
|
||||
)
|
||||
|
||||
if not data["api_key"]:
|
||||
raise ValueError(
|
||||
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
if not data["endpoint"]:
|
||||
raise ValueError(
|
||||
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
|
||||
)
|
||||
|
||||
# Credentials and endpoint are validated lazily in `_init_clients`
|
||||
# so the LLM can be constructed before deployment env vars are set.
|
||||
model = data.get("model", "")
|
||||
data["endpoint"] = AzureCompletion._validate_and_fix_endpoint(
|
||||
data["endpoint"], model
|
||||
if data["endpoint"]:
|
||||
data["endpoint"] = AzureCompletion._validate_and_fix_endpoint(
|
||||
data["endpoint"], model
|
||||
)
|
||||
data["is_azure_openai_endpoint"] = AzureCompletion._is_azure_openai_endpoint(
|
||||
data["endpoint"]
|
||||
)
|
||||
data["is_openai_model"] = any(
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
)
|
||||
parsed = urlparse(data["endpoint"])
|
||||
hostname = parsed.hostname or ""
|
||||
data["is_azure_openai_endpoint"] = (
|
||||
hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")
|
||||
) and "/openai/deployments/" in data["endpoint"]
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _is_azure_openai_endpoint(endpoint: str | None) -> bool:
|
||||
if not endpoint:
|
||||
return False
|
||||
hostname = urlparse(endpoint).hostname or ""
|
||||
return (
|
||||
hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")
|
||||
) and "/openai/deployments/" in endpoint
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> AzureCompletion:
|
||||
"""Eagerly build clients when credentials are available, otherwise
|
||||
defer so ``LLM(model="azure/...")`` can be constructed at module
|
||||
import time even before deployment env vars are set.
|
||||
"""
|
||||
try:
|
||||
self._client = self._build_sync_client()
|
||||
self._async_client = self._build_async_client()
|
||||
except ValueError:
|
||||
pass
|
||||
return self
|
||||
|
||||
def _build_sync_client(self) -> Any:
|
||||
return ChatCompletionsClient(**self._make_client_kwargs())
|
||||
|
||||
def _build_async_client(self) -> Any:
|
||||
return AsyncChatCompletionsClient(**self._make_client_kwargs())
|
||||
|
||||
def _make_client_kwargs(self) -> dict[str, Any]:
|
||||
# Re-read env vars so that a deferred build can pick up credentials
|
||||
# that weren't set at instantiation time (e.g. LLM constructed at
|
||||
# module import before deployment env vars were injected).
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key is required.")
|
||||
self.api_key = os.getenv("AZURE_API_KEY")
|
||||
if not self.endpoint:
|
||||
endpoint = (
|
||||
os.getenv("AZURE_ENDPOINT")
|
||||
or os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
or os.getenv("AZURE_API_BASE")
|
||||
)
|
||||
if endpoint:
|
||||
self.endpoint = AzureCompletion._validate_and_fix_endpoint(
|
||||
endpoint, self.model
|
||||
)
|
||||
# Recompute the routing flag now that the endpoint is known —
|
||||
# _prepare_completion_params uses it to decide whether to
|
||||
# include `model` in the request body (Azure OpenAI endpoints
|
||||
# embed the deployment name in the URL and reject it).
|
||||
self.is_azure_openai_endpoint = (
|
||||
AzureCompletion._is_azure_openai_endpoint(self.endpoint)
|
||||
)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Azure API key is required. Set AZURE_API_KEY environment "
|
||||
"variable or pass api_key parameter."
|
||||
)
|
||||
if not self.endpoint:
|
||||
raise ValueError(
|
||||
"Azure endpoint is required. Set AZURE_ENDPOINT environment "
|
||||
"variable or pass endpoint parameter."
|
||||
)
|
||||
client_kwargs: dict[str, Any] = {
|
||||
"endpoint": self.endpoint,
|
||||
"credential": AzureKeyCredential(self.api_key),
|
||||
}
|
||||
if self.api_version:
|
||||
client_kwargs["api_version"] = self.api_version
|
||||
return client_kwargs
|
||||
|
||||
self._client = ChatCompletionsClient(**client_kwargs)
|
||||
self._async_client = AsyncChatCompletionsClient(**client_kwargs)
|
||||
return self
|
||||
def _get_sync_client(self) -> Any:
|
||||
if self._client is None:
|
||||
self._client = self._build_sync_client()
|
||||
return self._client
|
||||
|
||||
def _get_async_client(self) -> Any:
|
||||
if self._async_client is None:
|
||||
self._async_client = self._build_async_client()
|
||||
return self._async_client
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Azure-specific fields."""
|
||||
@@ -713,8 +770,7 @@ class AzureCompletion(BaseLLM):
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming chat completion."""
|
||||
try:
|
||||
# Cast params to Any to avoid type checking issues with TypedDict unpacking
|
||||
response: ChatCompletions = self._client.complete(**params)
|
||||
response: ChatCompletions = self._get_sync_client().complete(**params)
|
||||
return self._process_completion_response(
|
||||
response=response,
|
||||
params=params,
|
||||
@@ -913,7 +969,7 @@ class AzureCompletion(BaseLLM):
|
||||
tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
usage_data: dict[str, Any] | None = None
|
||||
for update in self._client.complete(**params):
|
||||
for update in self._get_sync_client().complete(**params):
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if update.usage:
|
||||
usage = update.usage
|
||||
@@ -953,8 +1009,9 @@ class AzureCompletion(BaseLLM):
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming chat completion asynchronously."""
|
||||
try:
|
||||
# Cast params to Any to avoid type checking issues with TypedDict unpacking
|
||||
response: ChatCompletions = await self._async_client.complete(**params)
|
||||
response: ChatCompletions = await self._get_async_client().complete(
|
||||
**params
|
||||
)
|
||||
return self._process_completion_response(
|
||||
response=response,
|
||||
params=params,
|
||||
@@ -980,7 +1037,7 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
usage_data: dict[str, Any] | None = None
|
||||
|
||||
stream = await self._async_client.complete(**params)
|
||||
stream = await self._get_async_client().complete(**params)
|
||||
async for update in stream:
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if hasattr(update, "usage") and update.usage:
|
||||
@@ -1076,28 +1133,39 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
@staticmethod
|
||||
def _extract_azure_token_usage(response: ChatCompletions) -> dict[str, Any]:
|
||||
"""Extract token usage from Azure response."""
|
||||
"""Extract token usage and response metadata from Azure response."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = response.usage
|
||||
cached_tokens = 0
|
||||
prompt_details = getattr(usage, "prompt_tokens_details", None)
|
||||
if prompt_details:
|
||||
cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0
|
||||
return {
|
||||
reasoning_tokens = 0
|
||||
completion_details = getattr(usage, "completion_tokens_details", None)
|
||||
if completion_details:
|
||||
reasoning_tokens = (
|
||||
getattr(completion_details, "reasoning_tokens", 0) or 0
|
||||
)
|
||||
result: dict[str, Any] = {
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
||||
"cached_prompt_tokens": cached_tokens,
|
||||
"reasoning_tokens": reasoning_tokens,
|
||||
}
|
||||
return result
|
||||
return {"total_tokens": 0}
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close the async client and clean up resources.
|
||||
|
||||
This ensures proper cleanup of the underlying aiohttp session
|
||||
to avoid unclosed connector warnings.
|
||||
to avoid unclosed connector warnings. Accesses the cached client
|
||||
directly rather than going through `_get_async_client` so a
|
||||
cleanup on an uninitialized LLM is a harmless no-op rather than
|
||||
a credential-required error.
|
||||
"""
|
||||
if hasattr(self._async_client, "close"):
|
||||
if self._async_client is not None and hasattr(self._async_client, "close"):
|
||||
await self._async_client.close()
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing_extensions import Required
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
@@ -302,6 +303,22 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> BedrockCompletion:
|
||||
"""Eagerly build the sync client when AWS credentials resolve,
|
||||
otherwise defer so ``LLM(model="bedrock/...")`` can be constructed
|
||||
at module import time even before deployment env vars are set.
|
||||
|
||||
Only credential/SDK errors are caught — programming errors like
|
||||
``TypeError`` or ``AttributeError`` propagate so real bugs aren't
|
||||
silently swallowed.
|
||||
"""
|
||||
try:
|
||||
self._client = self._build_sync_client()
|
||||
except (BotoCoreError, ClientError, ValueError) as e:
|
||||
logging.debug("Deferring Bedrock client construction: %s", e)
|
||||
self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
|
||||
return self
|
||||
|
||||
def _build_sync_client(self) -> Any:
|
||||
config = Config(
|
||||
read_timeout=300,
|
||||
retries={"max_attempts": 3, "mode": "adaptive"},
|
||||
@@ -313,9 +330,17 @@ class BedrockCompletion(BaseLLM):
|
||||
aws_session_token=self.aws_session_token,
|
||||
region_name=self.region_name,
|
||||
)
|
||||
self._client = session.client("bedrock-runtime", config=config)
|
||||
self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
|
||||
return self
|
||||
return session.client("bedrock-runtime", config=config)
|
||||
|
||||
def _get_sync_client(self) -> Any:
|
||||
if self._client is None:
|
||||
self._client = self._build_sync_client()
|
||||
return self._client
|
||||
|
||||
def _get_async_client(self) -> Any:
|
||||
"""Async client is set up separately by ``_ensure_async_client``
|
||||
using ``aiobotocore`` inside an exit stack."""
|
||||
return self._async_client
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Bedrock-specific fields."""
|
||||
@@ -655,7 +680,7 @@ class BedrockCompletion(BaseLLM):
|
||||
raise ValueError(f"Invalid message format at index {i}")
|
||||
|
||||
# Call Bedrock Converse API with proper error handling
|
||||
response = self._client.converse(
|
||||
response = self._get_sync_client().converse(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
@@ -944,7 +969,7 @@ class BedrockCompletion(BaseLLM):
|
||||
usage_data: dict[str, Any] | None = None
|
||||
|
||||
try:
|
||||
response = self._client.converse_stream(
|
||||
response = self._get_sync_client().converse_stream(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
@@ -1948,8 +1973,6 @@ class BedrockCompletion(BaseLLM):
|
||||
tools: list[dict[str, Any]],
|
||||
) -> list[ConverseToolTypeDef]:
|
||||
"""Convert CrewAI tools to Converse API format following AWS specification."""
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
converse_tools: list[ConverseToolTypeDef] = []
|
||||
|
||||
for tool in tools:
|
||||
@@ -2025,11 +2048,18 @@ class BedrockCompletion(BaseLLM):
|
||||
input_tokens = usage.get("inputTokens", 0)
|
||||
output_tokens = usage.get("outputTokens", 0)
|
||||
total_tokens = usage.get("totalTokens", input_tokens + output_tokens)
|
||||
raw_cached = (
|
||||
usage.get("cacheReadInputTokenCount")
|
||||
or usage.get("cacheReadInputTokens")
|
||||
or 0
|
||||
)
|
||||
cached_tokens = raw_cached if isinstance(raw_cached, int) else 0
|
||||
|
||||
self._token_usage["prompt_tokens"] += input_tokens
|
||||
self._token_usage["completion_tokens"] += output_tokens
|
||||
self._token_usage["total_tokens"] += total_tokens
|
||||
self._token_usage["successful_requests"] += 1
|
||||
self._token_usage["cached_prompt_tokens"] += cached_tokens
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
|
||||
@@ -118,9 +118,33 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_client(self) -> GeminiCompletion:
|
||||
self._client = self._initialize_client(self.use_vertexai)
|
||||
"""Eagerly build the client when credentials resolve, otherwise defer
|
||||
so ``LLM(model="gemini/...")`` can be constructed at module import time
|
||||
even before deployment env vars are set.
|
||||
"""
|
||||
try:
|
||||
self._client = self._initialize_client(self.use_vertexai)
|
||||
except ValueError:
|
||||
pass
|
||||
return self
|
||||
|
||||
def _get_sync_client(self) -> Any:
|
||||
if self._client is None:
|
||||
# Re-read env vars so a deferred build can pick up credentials
|
||||
# that weren't set at instantiation time.
|
||||
if not self.api_key:
|
||||
self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv(
|
||||
"GEMINI_API_KEY"
|
||||
)
|
||||
if not self.project:
|
||||
self.project = os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
self._client = self._initialize_client(self.use_vertexai)
|
||||
return self._client
|
||||
|
||||
def _get_async_client(self) -> Any:
|
||||
"""Gemini uses a single client for both sync and async calls."""
|
||||
return self._get_sync_client()
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Gemini/Vertex-specific fields."""
|
||||
config = super().to_config_dict()
|
||||
@@ -228,6 +252,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
if (
|
||||
hasattr(self, "client")
|
||||
and self._client is not None
|
||||
and hasattr(self._client, "vertexai")
|
||||
and self._client.vertexai
|
||||
):
|
||||
@@ -1112,7 +1137,7 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = self._client.models.generate_content(
|
||||
response = self._get_sync_client().models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1153,7 +1178,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
for chunk in self._client.models.generate_content_stream(
|
||||
for chunk in self._get_sync_client().models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1191,7 +1216,7 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = await self._client.aio.models.generate_content(
|
||||
response = await self._get_async_client().aio.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1232,7 +1257,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
stream = await self._client.aio.models.generate_content_stream(
|
||||
stream = await self._get_async_client().aio.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1306,17 +1331,20 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
@staticmethod
|
||||
def _extract_token_usage(response: GenerateContentResponse) -> dict[str, Any]:
|
||||
"""Extract token usage from Gemini response."""
|
||||
"""Extract token usage and response metadata from Gemini response."""
|
||||
if response.usage_metadata:
|
||||
usage = response.usage_metadata
|
||||
cached_tokens = getattr(usage, "cached_content_token_count", 0) or 0
|
||||
return {
|
||||
thinking_tokens = getattr(usage, "thoughts_token_count", 0) or 0
|
||||
result: dict[str, Any] = {
|
||||
"prompt_token_count": getattr(usage, "prompt_token_count", 0),
|
||||
"candidates_token_count": getattr(usage, "candidates_token_count", 0),
|
||||
"total_token_count": getattr(usage, "total_token_count", 0),
|
||||
"total_tokens": getattr(usage, "total_token_count", 0),
|
||||
"cached_prompt_tokens": cached_tokens,
|
||||
"reasoning_tokens": thinking_tokens,
|
||||
}
|
||||
return result
|
||||
return {"total_tokens": 0}
|
||||
|
||||
@staticmethod
|
||||
@@ -1436,6 +1464,6 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
from crewai_files.uploaders.gemini import GeminiFileUploader
|
||||
|
||||
return GeminiFileUploader(client=self._client)
|
||||
return GeminiFileUploader(client=self._get_sync_client())
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -32,11 +32,15 @@ from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
generate_model_description,
|
||||
sanitize_tool_params_for_openai_strict,
|
||||
)
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
@@ -253,22 +257,40 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> OpenAICompletion:
|
||||
"""Eagerly build clients when the API key is available, otherwise
|
||||
defer so ``LLM(model="openai/...")`` can be constructed at module
|
||||
import time even before deployment env vars are set.
|
||||
"""
|
||||
try:
|
||||
self._client = self._build_sync_client()
|
||||
self._async_client = self._build_async_client()
|
||||
except ValueError:
|
||||
pass
|
||||
return self
|
||||
|
||||
def _build_sync_client(self) -> Any:
|
||||
client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
transport = HTTPTransport(interceptor=self.interceptor)
|
||||
http_client = httpx.Client(transport=transport)
|
||||
client_config["http_client"] = http_client
|
||||
client_config["http_client"] = httpx.Client(transport=transport)
|
||||
return OpenAI(**client_config)
|
||||
|
||||
self._client = OpenAI(**client_config)
|
||||
|
||||
async_client_config = self._get_client_params()
|
||||
def _build_async_client(self) -> Any:
|
||||
client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
|
||||
async_http_client = httpx.AsyncClient(transport=async_transport)
|
||||
async_client_config["http_client"] = async_http_client
|
||||
transport = AsyncHTTPTransport(interceptor=self.interceptor)
|
||||
client_config["http_client"] = httpx.AsyncClient(transport=transport)
|
||||
return AsyncOpenAI(**client_config)
|
||||
|
||||
self._async_client = AsyncOpenAI(**async_client_config)
|
||||
return self
|
||||
def _get_sync_client(self) -> Any:
|
||||
if self._client is None:
|
||||
self._client = self._build_sync_client()
|
||||
return self._client
|
||||
|
||||
def _get_async_client(self) -> Any:
|
||||
if self._async_client is None:
|
||||
self._async_client = self._build_async_client()
|
||||
return self._async_client
|
||||
|
||||
@property
|
||||
def last_response_id(self) -> str | None:
|
||||
@@ -764,8 +786,6 @@ class OpenAICompletion(BaseLLM):
|
||||
"function": {"name": "...", "description": "...", "parameters": {...}}
|
||||
}
|
||||
"""
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
responses_tools = []
|
||||
|
||||
for tool in tools:
|
||||
@@ -797,7 +817,7 @@ class OpenAICompletion(BaseLLM):
|
||||
) -> str | ResponsesAPIResult | Any:
|
||||
"""Handle non-streaming Responses API call."""
|
||||
try:
|
||||
response: Response = self._client.responses.create(**params)
|
||||
response: Response = self._get_sync_client().responses.create(**params)
|
||||
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and response.id:
|
||||
@@ -933,7 +953,9 @@ class OpenAICompletion(BaseLLM):
|
||||
) -> str | ResponsesAPIResult | Any:
|
||||
"""Handle async non-streaming Responses API call."""
|
||||
try:
|
||||
response: Response = await self._async_client.responses.create(**params)
|
||||
response: Response = await self._get_async_client().responses.create(
|
||||
**params
|
||||
)
|
||||
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and response.id:
|
||||
@@ -1069,7 +1091,7 @@ class OpenAICompletion(BaseLLM):
|
||||
final_response: Response | None = None
|
||||
usage: dict[str, Any] | None = None
|
||||
|
||||
stream = self._client.responses.create(**params)
|
||||
stream = self._get_sync_client().responses.create(**params)
|
||||
response_id_stream = None
|
||||
|
||||
for event in stream:
|
||||
@@ -1197,7 +1219,7 @@ class OpenAICompletion(BaseLLM):
|
||||
final_response: Response | None = None
|
||||
usage: dict[str, Any] | None = None
|
||||
|
||||
stream = await self._async_client.responses.create(**params)
|
||||
stream = await self._get_async_client().responses.create(**params)
|
||||
response_id_stream = None
|
||||
|
||||
async for event in stream:
|
||||
@@ -1324,19 +1346,23 @@ class OpenAICompletion(BaseLLM):
|
||||
]
|
||||
|
||||
def _extract_responses_token_usage(self, response: Response) -> dict[str, Any]:
|
||||
"""Extract token usage from Responses API response."""
|
||||
"""Extract token usage and response metadata from Responses API response."""
|
||||
if response.usage:
|
||||
result = {
|
||||
result: dict[str, Any] = {
|
||||
"prompt_tokens": response.usage.input_tokens,
|
||||
"completion_tokens": response.usage.output_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
}
|
||||
# Extract cached prompt tokens from input_tokens_details
|
||||
input_details = getattr(response.usage, "input_tokens_details", None)
|
||||
if input_details:
|
||||
result["cached_prompt_tokens"] = (
|
||||
getattr(input_details, "cached_tokens", 0) or 0
|
||||
)
|
||||
output_details = getattr(response.usage, "output_tokens_details", None)
|
||||
if output_details:
|
||||
result["reasoning_tokens"] = (
|
||||
getattr(output_details, "reasoning_tokens", 0) or 0
|
||||
)
|
||||
return result
|
||||
return {"total_tokens": 0}
|
||||
|
||||
@@ -1544,11 +1570,6 @@ class OpenAICompletion(BaseLLM):
|
||||
self, tools: list[dict[str, BaseTool]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert CrewAI tool format to OpenAI function calling format."""
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
force_additional_properties_false,
|
||||
)
|
||||
|
||||
openai_tools = []
|
||||
|
||||
for tool in tools:
|
||||
@@ -1567,8 +1588,9 @@ class OpenAICompletion(BaseLLM):
|
||||
params_dict = (
|
||||
parameters if isinstance(parameters, dict) else dict(parameters)
|
||||
)
|
||||
params_dict = force_additional_properties_false(params_dict)
|
||||
openai_tool["function"]["parameters"] = params_dict
|
||||
openai_tool["function"]["parameters"] = (
|
||||
sanitize_tool_params_for_openai_strict(params_dict)
|
||||
)
|
||||
|
||||
openai_tools.append(openai_tool)
|
||||
return openai_tools
|
||||
@@ -1587,7 +1609,7 @@ class OpenAICompletion(BaseLLM):
|
||||
parse_params = {
|
||||
k: v for k, v in params.items() if k != "response_format"
|
||||
}
|
||||
parsed_response = self._client.beta.chat.completions.parse(
|
||||
parsed_response = self._get_sync_client().beta.chat.completions.parse(
|
||||
**parse_params,
|
||||
response_format=response_model,
|
||||
)
|
||||
@@ -1611,7 +1633,9 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return parsed_object
|
||||
|
||||
response: ChatCompletion = self._client.chat.completions.create(**params)
|
||||
response: ChatCompletion = self._get_sync_client().chat.completions.create(
|
||||
**params
|
||||
)
|
||||
|
||||
usage = self._extract_openai_token_usage(response)
|
||||
|
||||
@@ -1838,7 +1862,7 @@ class OpenAICompletion(BaseLLM):
|
||||
}
|
||||
|
||||
stream: ChatCompletionStream[BaseModel]
|
||||
with self._client.beta.chat.completions.stream(
|
||||
with self._get_sync_client().beta.chat.completions.stream(
|
||||
**parse_params, response_format=response_model
|
||||
) as stream:
|
||||
for chunk in stream:
|
||||
@@ -1875,7 +1899,7 @@ class OpenAICompletion(BaseLLM):
|
||||
return ""
|
||||
|
||||
completion_stream: Stream[ChatCompletionChunk] = (
|
||||
self._client.chat.completions.create(**params)
|
||||
self._get_sync_client().chat.completions.create(**params)
|
||||
)
|
||||
|
||||
usage_data: dict[str, Any] | None = None
|
||||
@@ -1972,9 +1996,11 @@ class OpenAICompletion(BaseLLM):
|
||||
parse_params = {
|
||||
k: v for k, v in params.items() if k != "response_format"
|
||||
}
|
||||
parsed_response = await self._async_client.beta.chat.completions.parse(
|
||||
**parse_params,
|
||||
response_format=response_model,
|
||||
parsed_response = (
|
||||
await self._get_async_client().beta.chat.completions.parse(
|
||||
**parse_params,
|
||||
response_format=response_model,
|
||||
)
|
||||
)
|
||||
math_reasoning = parsed_response.choices[0].message
|
||||
|
||||
@@ -1996,8 +2022,8 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return parsed_object
|
||||
|
||||
response: ChatCompletion = await self._async_client.chat.completions.create(
|
||||
**params
|
||||
response: ChatCompletion = (
|
||||
await self._get_async_client().chat.completions.create(**params)
|
||||
)
|
||||
|
||||
usage = self._extract_openai_token_usage(response)
|
||||
@@ -2123,7 +2149,7 @@ class OpenAICompletion(BaseLLM):
|
||||
if response_model:
|
||||
completion_stream: AsyncIterator[
|
||||
ChatCompletionChunk
|
||||
] = await self._async_client.chat.completions.create(**params)
|
||||
] = await self._get_async_client().chat.completions.create(**params)
|
||||
|
||||
accumulated_content = ""
|
||||
usage_data: dict[str, Any] | None = None
|
||||
@@ -2179,7 +2205,7 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
stream: AsyncIterator[
|
||||
ChatCompletionChunk
|
||||
] = await self._async_client.chat.completions.create(**params)
|
||||
] = await self._get_async_client().chat.completions.create(**params)
|
||||
|
||||
usage_data = None
|
||||
|
||||
@@ -2307,20 +2333,24 @@ class OpenAICompletion(BaseLLM):
|
||||
def _extract_openai_token_usage(
|
||||
self, response: ChatCompletion | ChatCompletionChunk
|
||||
) -> dict[str, Any]:
|
||||
"""Extract token usage from OpenAI ChatCompletion or ChatCompletionChunk response."""
|
||||
"""Extract token usage and response metadata from OpenAI ChatCompletion."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = response.usage
|
||||
result = {
|
||||
result: dict[str, Any] = {
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
||||
}
|
||||
# Extract cached prompt tokens from prompt_tokens_details
|
||||
prompt_details = getattr(usage, "prompt_tokens_details", None)
|
||||
if prompt_details:
|
||||
result["cached_prompt_tokens"] = (
|
||||
getattr(prompt_details, "cached_tokens", 0) or 0
|
||||
)
|
||||
completion_details = getattr(usage, "completion_tokens_details", None)
|
||||
if completion_details:
|
||||
result["reasoning_tokens"] = (
|
||||
getattr(completion_details, "reasoning_tokens", 0) or 0
|
||||
)
|
||||
return result
|
||||
return {"total_tokens": 0}
|
||||
|
||||
@@ -2371,8 +2401,8 @@ class OpenAICompletion(BaseLLM):
|
||||
from crewai_files.uploaders.openai import OpenAIFileUploader
|
||||
|
||||
return OpenAIFileUploader(
|
||||
client=self._client,
|
||||
async_client=self._async_client,
|
||||
client=self._get_sync_client(),
|
||||
async_client=self._get_async_client(),
|
||||
)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -417,9 +417,18 @@ class MCPToolResolver:
|
||||
|
||||
args_schema = None
|
||||
if tool_def.get("inputSchema"):
|
||||
args_schema = self._json_schema_to_pydantic(
|
||||
tool_name, tool_def["inputSchema"]
|
||||
)
|
||||
try:
|
||||
args_schema = self._json_schema_to_pydantic(
|
||||
tool_name, tool_def["inputSchema"]
|
||||
)
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Failed to build args schema for MCP tool "
|
||||
f"'{tool_name}': {e}. Registering tool without a "
|
||||
"typed schema.",
|
||||
)
|
||||
args_schema = None
|
||||
|
||||
tool_schema = {
|
||||
"description": tool_def.get("description", ""),
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.memory.types import MemoryPromptConfig, MemoryRecord, ScopeInfo
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@@ -149,7 +149,7 @@ def _memory_prompt_line(
|
||||
raw = getattr(memory_prompt, key, None)
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
return raw
|
||||
return get_i18n().memory(key)
|
||||
return I18N_DEFAULT.memory(key)
|
||||
|
||||
|
||||
def extract_memories_from_content(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
@@ -201,11 +202,20 @@ class CheckpointConfig(BaseModel):
|
||||
description="Maximum checkpoints to keep. Oldest are pruned after "
|
||||
"each write. None means keep all.",
|
||||
)
|
||||
restore_from: Path | str | None = Field(
|
||||
default=None,
|
||||
description="Path or location of a checkpoint to restore from. "
|
||||
"When passed via a kickoff method's from_checkpoint parameter, "
|
||||
"the crew or flow resumes from this checkpoint.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _register_handlers(self) -> CheckpointConfig:
|
||||
from crewai.state.checkpoint_listener import _ensure_handlers_registered
|
||||
|
||||
if isinstance(self.provider, SqliteProvider) and not Path(self.location).suffix:
|
||||
self.location = f"{self.location}.db"
|
||||
|
||||
_ensure_handlers_registered()
|
||||
return self
|
||||
|
||||
@@ -216,3 +226,25 @@ class CheckpointConfig(BaseModel):
|
||||
@property
|
||||
def trigger_events(self) -> set[str]:
|
||||
return set(self.on_events)
|
||||
|
||||
|
||||
def apply_checkpoint(instance: Any, from_checkpoint: CheckpointConfig | None) -> Any:
|
||||
"""Handle checkpoint config for a kickoff method.
|
||||
|
||||
If *from_checkpoint* carries a ``restore_from`` path, builds and returns a
|
||||
restored instance (with ``restore_from`` cleared). The caller should
|
||||
dispatch into its own kickoff variant on that restored instance.
|
||||
|
||||
If *from_checkpoint* is present but has no ``restore_from``, sets
|
||||
``instance.checkpoint`` and returns ``None`` (proceed normally).
|
||||
|
||||
If *from_checkpoint* is ``None``, returns ``None`` immediately.
|
||||
"""
|
||||
if from_checkpoint is None:
|
||||
return None
|
||||
if from_checkpoint.restore_from is not None:
|
||||
restored = type(instance).from_checkpoint(from_checkpoint)
|
||||
restored.checkpoint = from_checkpoint.model_copy(update={"restore_from": None})
|
||||
return restored
|
||||
instance.checkpoint = from_checkpoint
|
||||
return None
|
||||
|
||||
@@ -7,6 +7,7 @@ avoids per-event overhead when no entity uses checkpointing.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
@@ -102,14 +103,31 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
|
||||
return None
|
||||
|
||||
|
||||
def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None:
|
||||
def _do_checkpoint(
|
||||
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
|
||||
) -> None:
|
||||
"""Write a checkpoint and prune old ones if configured."""
|
||||
_prepare_entities(state.root)
|
||||
data = state.model_dump_json()
|
||||
cfg.provider.checkpoint(data, cfg.location)
|
||||
payload = state.model_dump(mode="json")
|
||||
if event is not None:
|
||||
payload["trigger"] = event.type
|
||||
data = json.dumps(payload)
|
||||
location = cfg.provider.checkpoint(
|
||||
data,
|
||||
cfg.location,
|
||||
parent_id=state._parent_id,
|
||||
branch=state._branch,
|
||||
)
|
||||
state._chain_lineage(cfg.provider, location)
|
||||
|
||||
checkpoint_id: str = cfg.provider.extract_id(location)
|
||||
msg: str = (
|
||||
f"Checkpoint saved. Resume with: crewai checkpoint resume {checkpoint_id}"
|
||||
)
|
||||
logger.info(msg)
|
||||
|
||||
if cfg.max_checkpoints is not None:
|
||||
cfg.provider.prune(cfg.location, cfg.max_checkpoints)
|
||||
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
|
||||
|
||||
|
||||
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:
|
||||
@@ -128,7 +146,7 @@ def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None:
|
||||
if cfg is None:
|
||||
return
|
||||
try:
|
||||
_do_checkpoint(state, cfg)
|
||||
_do_checkpoint(state, cfg, event)
|
||||
except Exception:
|
||||
logger.warning("Auto-checkpoint failed for event %s", event.type, exc_info=True)
|
||||
|
||||
|
||||
@@ -17,12 +17,21 @@ class BaseProvider(BaseModel, ABC):
|
||||
provider_type: str = "base"
|
||||
|
||||
@abstractmethod
|
||||
def checkpoint(self, data: str, location: str) -> str:
|
||||
def checkpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Persist a snapshot synchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized string to persist.
|
||||
location: Storage destination (directory, file path, URI, etc.).
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
branch: Branch label for this checkpoint.
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
@@ -30,12 +39,21 @@ class BaseProvider(BaseModel, ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def acheckpoint(self, data: str, location: str) -> str:
|
||||
async def acheckpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Persist a snapshot asynchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized string to persist.
|
||||
location: Storage destination (directory, file path, URI, etc.).
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
branch: Branch label for this checkpoint.
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
@@ -43,12 +61,25 @@ class BaseProvider(BaseModel, ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def prune(self, location: str, max_keep: int) -> None:
|
||||
"""Remove old checkpoints, keeping at most *max_keep*.
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
|
||||
"""Remove old checkpoints, keeping at most *max_keep* per branch.
|
||||
|
||||
Args:
|
||||
location: The storage destination passed to ``checkpoint``.
|
||||
max_keep: Maximum number of checkpoints to retain.
|
||||
branch: Only prune checkpoints on this branch.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def extract_id(self, location: str) -> str:
|
||||
"""Extract the checkpoint ID from a location string.
|
||||
|
||||
Args:
|
||||
location: The identifier returned by a previous ``checkpoint`` call.
|
||||
|
||||
Returns:
|
||||
The checkpoint ID.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@@ -19,48 +19,87 @@ from crewai.state.provider.core import BaseProvider
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _safe_branch(base: str, branch: str) -> None:
|
||||
"""Validate that a branch name doesn't escape the base directory.
|
||||
|
||||
Raises:
|
||||
ValueError: If the branch resolves outside the base directory.
|
||||
"""
|
||||
base_resolved = str(Path(base).resolve())
|
||||
target_resolved = str((Path(base) / branch).resolve())
|
||||
if (
|
||||
not target_resolved.startswith(base_resolved + os.sep)
|
||||
and target_resolved != base_resolved
|
||||
):
|
||||
raise ValueError(f"Branch name escapes checkpoint directory: {branch!r}")
|
||||
|
||||
|
||||
class JsonProvider(BaseProvider):
|
||||
"""Persists runtime state checkpoints as JSON files on the local filesystem."""
|
||||
|
||||
provider_type: Literal["json"] = "json"
|
||||
|
||||
def checkpoint(self, data: str, location: str) -> str:
|
||||
def checkpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Write a JSON checkpoint file.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
location: Directory where the checkpoint will be saved.
|
||||
location: Base directory where checkpoints are saved.
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
Encoded in the filename for queryable lineage without
|
||||
parsing the blob.
|
||||
branch: Branch label. Files are stored under ``location/branch/``.
|
||||
|
||||
Returns:
|
||||
The path to the written checkpoint file.
|
||||
"""
|
||||
file_path = _build_path(location)
|
||||
file_path = _build_path(location, branch, parent_id)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
f.write(data)
|
||||
return str(file_path)
|
||||
|
||||
async def acheckpoint(self, data: str, location: str) -> str:
|
||||
async def acheckpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Write a JSON checkpoint file asynchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
location: Directory where the checkpoint will be saved.
|
||||
location: Base directory where checkpoints are saved.
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
Encoded in the filename for queryable lineage without
|
||||
parsing the blob.
|
||||
branch: Branch label. Files are stored under ``location/branch/``.
|
||||
|
||||
Returns:
|
||||
The path to the written checkpoint file.
|
||||
"""
|
||||
file_path = _build_path(location)
|
||||
file_path = _build_path(location, branch, parent_id)
|
||||
await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
|
||||
|
||||
async with aiofiles.open(file_path, "w") as f:
|
||||
await f.write(data)
|
||||
return str(file_path)
|
||||
|
||||
def prune(self, location: str, max_keep: int) -> None:
|
||||
"""Remove oldest checkpoint files beyond *max_keep*."""
|
||||
pattern = os.path.join(location, "*.json")
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
|
||||
"""Remove oldest checkpoint files beyond *max_keep* on a branch."""
|
||||
_safe_branch(location, branch)
|
||||
branch_dir = os.path.join(location, branch)
|
||||
pattern = os.path.join(branch_dir, "*.json")
|
||||
files = sorted(glob.glob(pattern), key=os.path.getmtime)
|
||||
for path in files if max_keep == 0 else files[:-max_keep]:
|
||||
try:
|
||||
@@ -68,6 +107,16 @@ class JsonProvider(BaseProvider):
|
||||
except OSError: # noqa: PERF203
|
||||
logger.debug("Failed to remove %s", path, exc_info=True)
|
||||
|
||||
def extract_id(self, location: str) -> str:
|
||||
"""Extract the checkpoint ID from a file path.
|
||||
|
||||
The filename format is ``{ts}_{uuid8}_p-{parent}.json``.
|
||||
The checkpoint ID is the ``{ts}_{uuid8}`` prefix.
|
||||
"""
|
||||
stem = Path(location).stem
|
||||
idx = stem.find("_p-")
|
||||
return stem[:idx] if idx != -1 else stem
|
||||
|
||||
def from_checkpoint(self, location: str) -> str:
|
||||
"""Read a JSON checkpoint file.
|
||||
|
||||
@@ -92,15 +141,24 @@ class JsonProvider(BaseProvider):
|
||||
return await f.read()
|
||||
|
||||
|
||||
def _build_path(directory: str) -> Path:
|
||||
"""Build a timestamped checkpoint file path.
|
||||
def _build_path(
|
||||
directory: str, branch: str = "main", parent_id: str | None = None
|
||||
) -> Path:
|
||||
"""Build a timestamped checkpoint file path under a branch subdirectory.
|
||||
|
||||
Filename format: ``{ts}_{uuid8}_p-{parent_id}.json``
|
||||
|
||||
Args:
|
||||
directory: Parent directory for the checkpoint file.
|
||||
directory: Base directory for checkpoints.
|
||||
branch: Branch label used as a subdirectory name.
|
||||
parent_id: Parent checkpoint ID to encode in the filename.
|
||||
|
||||
Returns:
|
||||
The target file path.
|
||||
"""
|
||||
_safe_branch(directory, branch)
|
||||
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
|
||||
filename = f"{ts}_{uuid.uuid4().hex[:8]}.json"
|
||||
return Path(directory) / filename
|
||||
short_uuid = uuid.uuid4().hex[:8]
|
||||
parent_suffix = parent_id or "none"
|
||||
filename = f"{ts}_{short_uuid}_p-{parent_suffix}.json"
|
||||
return Path(directory) / branch / filename
|
||||
|
||||
@@ -17,15 +17,20 @@ _CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS checkpoints (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at TEXT NOT NULL,
|
||||
parent_id TEXT,
|
||||
branch TEXT NOT NULL DEFAULT 'main',
|
||||
data JSONB NOT NULL
|
||||
)
|
||||
"""
|
||||
|
||||
_INSERT = "INSERT INTO checkpoints (id, created_at, data) VALUES (?, ?, jsonb(?))"
|
||||
_INSERT = (
|
||||
"INSERT INTO checkpoints (id, created_at, parent_id, branch, data) "
|
||||
"VALUES (?, ?, ?, ?, jsonb(?))"
|
||||
)
|
||||
_SELECT = "SELECT json(data) FROM checkpoints WHERE id = ?"
|
||||
_PRUNE = """
|
||||
DELETE FROM checkpoints WHERE rowid NOT IN (
|
||||
SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ?
|
||||
DELETE FROM checkpoints WHERE branch = ? AND rowid NOT IN (
|
||||
SELECT rowid FROM checkpoints WHERE branch = ? ORDER BY rowid DESC LIMIT ?
|
||||
)
|
||||
"""
|
||||
|
||||
@@ -50,12 +55,21 @@ class SqliteProvider(BaseProvider):
|
||||
|
||||
provider_type: Literal["sqlite"] = "sqlite"
|
||||
|
||||
def checkpoint(self, data: str, location: str) -> str:
|
||||
def checkpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Write a checkpoint to the SQLite database.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
location: Path to the SQLite database file.
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
branch: Branch label for this checkpoint.
|
||||
|
||||
Returns:
|
||||
A location string in the format ``"db_path#checkpoint_id"``.
|
||||
@@ -65,16 +79,25 @@ class SqliteProvider(BaseProvider):
|
||||
with sqlite3.connect(location) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute(_CREATE_TABLE)
|
||||
conn.execute(_INSERT, (checkpoint_id, ts, data))
|
||||
conn.execute(_INSERT, (checkpoint_id, ts, parent_id, branch, data))
|
||||
conn.commit()
|
||||
return f"{location}#{checkpoint_id}"
|
||||
|
||||
async def acheckpoint(self, data: str, location: str) -> str:
|
||||
async def acheckpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Write a checkpoint to the SQLite database asynchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
location: Path to the SQLite database file.
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
branch: Branch label for this checkpoint.
|
||||
|
||||
Returns:
|
||||
A location string in the format ``"db_path#checkpoint_id"``.
|
||||
@@ -84,16 +107,20 @@ class SqliteProvider(BaseProvider):
|
||||
async with aiosqlite.connect(location) as db:
|
||||
await db.execute("PRAGMA journal_mode=WAL")
|
||||
await db.execute(_CREATE_TABLE)
|
||||
await db.execute(_INSERT, (checkpoint_id, ts, data))
|
||||
await db.execute(_INSERT, (checkpoint_id, ts, parent_id, branch, data))
|
||||
await db.commit()
|
||||
return f"{location}#{checkpoint_id}"
|
||||
|
||||
def prune(self, location: str, max_keep: int) -> None:
|
||||
"""Remove oldest checkpoint rows beyond *max_keep*."""
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
|
||||
"""Remove oldest checkpoint rows beyond *max_keep* on a branch."""
|
||||
with sqlite3.connect(location) as conn:
|
||||
conn.execute(_PRUNE, (max_keep,))
|
||||
conn.execute(_PRUNE, (branch, branch, max_keep))
|
||||
conn.commit()
|
||||
|
||||
def extract_id(self, location: str) -> str:
|
||||
"""Extract the checkpoint ID from a ``db_path#id`` string."""
|
||||
return location.rsplit("#", 1)[1]
|
||||
|
||||
def from_checkpoint(self, location: str) -> str:
|
||||
"""Read a checkpoint from the SQLite database.
|
||||
|
||||
|
||||
34
lib/crewai/src/crewai/state/provider/utils.py
Normal file
34
lib/crewai/src/crewai/state/provider/utils.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Provider detection utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
|
||||
|
||||
_SQLITE_MAGIC = b"SQLite format 3\x00"
|
||||
|
||||
|
||||
def detect_provider(path: str) -> BaseProvider:
|
||||
"""Detect the storage provider from a checkpoint path.
|
||||
|
||||
Reads the file's magic bytes to determine if it's a SQLite database.
|
||||
For paths containing ``#``, checks the portion before the ``#``.
|
||||
Falls back to JsonProvider.
|
||||
|
||||
Args:
|
||||
path: A checkpoint file path, directory, or ``db_path#checkpoint_id``.
|
||||
|
||||
Returns:
|
||||
The appropriate provider instance.
|
||||
"""
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
from crewai.state.provider.sqlite_provider import SqliteProvider
|
||||
|
||||
file_path = path.split("#")[0] if "#" in path else path
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
if f.read(16) == _SQLITE_MAGIC:
|
||||
return SqliteProvider()
|
||||
except OSError:
|
||||
pass
|
||||
return JsonProvider()
|
||||
@@ -9,8 +9,11 @@ via ``RuntimeState.model_rebuild()``.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import (
|
||||
ModelWrapValidatorHandler,
|
||||
PrivateAttr,
|
||||
@@ -20,9 +23,14 @@ from pydantic import (
|
||||
)
|
||||
|
||||
from crewai.context import capture_execution_context
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
from crewai.state.event_record import EventRecord
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
from crewai.utilities.version import get_crewai_version
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -58,12 +66,51 @@ def _sync_checkpoint_fields(entity: object) -> None:
|
||||
entity.checkpoint_inputs = entity._inputs
|
||||
entity.checkpoint_train = entity._train
|
||||
entity.checkpoint_kickoff_event_id = entity._kickoff_event_id
|
||||
for task in entity.tasks:
|
||||
task.checkpoint_original_description = task._original_description
|
||||
task.checkpoint_original_expected_output = task._original_expected_output
|
||||
|
||||
|
||||
def _migrate(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Apply version-based migrations to checkpoint data.
|
||||
|
||||
Each block handles checkpoints older than a specific version,
|
||||
transforming them forward to the current format. Blocks run in
|
||||
version order so migrations compose.
|
||||
|
||||
Args:
|
||||
data: The raw deserialized checkpoint dict.
|
||||
|
||||
Returns:
|
||||
The migrated checkpoint dict.
|
||||
"""
|
||||
raw = data.get("crewai_version")
|
||||
current = Version(get_crewai_version())
|
||||
stored = Version(raw) if raw else Version("0.0.0")
|
||||
|
||||
if raw is None:
|
||||
logger.warning("Checkpoint has no crewai_version — treating as 0.0.0")
|
||||
elif stored != current:
|
||||
logger.debug(
|
||||
"Migrating checkpoint from crewAI %s to %s",
|
||||
stored,
|
||||
current,
|
||||
)
|
||||
|
||||
# --- migrations in version order ---
|
||||
# if stored < Version("X.Y.Z"):
|
||||
# data.setdefault("some_field", "default")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
root: list[Entity]
|
||||
_provider: BaseProvider = PrivateAttr(default_factory=JsonProvider)
|
||||
_event_record: EventRecord = PrivateAttr(default_factory=EventRecord)
|
||||
_checkpoint_id: str | None = PrivateAttr(default=None)
|
||||
_parent_id: str | None = PrivateAttr(default=None)
|
||||
_branch: str = PrivateAttr(default="main")
|
||||
|
||||
@property
|
||||
def event_record(self) -> EventRecord:
|
||||
@@ -73,8 +120,11 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
@model_serializer(mode="plain")
|
||||
def _serialize(self) -> dict[str, Any]:
|
||||
return {
|
||||
"crewai_version": get_crewai_version(),
|
||||
"parent_id": self._parent_id,
|
||||
"branch": self._branch,
|
||||
"entities": [e.model_dump(mode="json") for e in self.root],
|
||||
"event_record": self._event_record.model_dump(),
|
||||
"event_record": self._event_record.model_dump(mode="json"),
|
||||
}
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@@ -83,13 +133,29 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
cls, data: Any, handler: ModelWrapValidatorHandler[RuntimeState]
|
||||
) -> RuntimeState:
|
||||
if isinstance(data, dict) and "entities" in data:
|
||||
data = _migrate(data)
|
||||
record_data = data.get("event_record")
|
||||
state = handler(data["entities"])
|
||||
if record_data:
|
||||
state._event_record = EventRecord.model_validate(record_data)
|
||||
state._parent_id = data.get("parent_id")
|
||||
state._branch = data.get("branch", "main")
|
||||
return state
|
||||
return handler(data)
|
||||
|
||||
def _chain_lineage(self, provider: BaseProvider, location: str) -> None:
|
||||
"""Update lineage fields after a successful checkpoint write.
|
||||
|
||||
Sets ``_checkpoint_id`` and ``_parent_id`` so the next write
|
||||
records the correct parent in the lineage chain.
|
||||
|
||||
Args:
|
||||
provider: The provider that performed the write.
|
||||
location: The location string returned by the provider.
|
||||
"""
|
||||
self._checkpoint_id = provider.extract_id(location)
|
||||
self._parent_id = self._checkpoint_id
|
||||
|
||||
def checkpoint(self, location: str) -> str:
|
||||
"""Write a checkpoint.
|
||||
|
||||
@@ -101,7 +167,14 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
_prepare_entities(self.root)
|
||||
return self._provider.checkpoint(self.model_dump_json(), location)
|
||||
result = self._provider.checkpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=self._parent_id,
|
||||
branch=self._branch,
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
return result
|
||||
|
||||
async def acheckpoint(self, location: str) -> str:
|
||||
"""Async version of :meth:`checkpoint`.
|
||||
@@ -114,41 +187,84 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
_prepare_entities(self.root)
|
||||
return await self._provider.acheckpoint(self.model_dump_json(), location)
|
||||
result = await self._provider.acheckpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=self._parent_id,
|
||||
branch=self._branch,
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
return result
|
||||
|
||||
def fork(self, branch: str | None = None) -> None:
|
||||
"""Create a new execution branch and write an initial checkpoint.
|
||||
|
||||
If this state was restored from a checkpoint, an initial checkpoint
|
||||
is written on the new branch so the fork point is recorded.
|
||||
|
||||
Args:
|
||||
branch: Branch label. Auto-generated from the current checkpoint
|
||||
ID if not provided. Always unique — safe to call multiple
|
||||
times without collisions.
|
||||
"""
|
||||
if branch:
|
||||
self._branch = branch
|
||||
elif self._checkpoint_id:
|
||||
self._branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}"
|
||||
else:
|
||||
self._branch = f"fork/{uuid.uuid4().hex[:8]}"
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls, location: str, provider: BaseProvider, **kwargs: Any
|
||||
) -> RuntimeState:
|
||||
def from_checkpoint(cls, config: CheckpointConfig, **kwargs: Any) -> RuntimeState:
|
||||
"""Restore a RuntimeState from a checkpoint.
|
||||
|
||||
Args:
|
||||
location: The identifier returned by a previous ``checkpoint`` call.
|
||||
provider: The storage backend to read from.
|
||||
config: Checkpoint configuration with ``restore_from`` set.
|
||||
**kwargs: Passed to ``model_validate_json``.
|
||||
|
||||
Returns:
|
||||
A restored RuntimeState.
|
||||
"""
|
||||
from crewai.state.provider.utils import detect_provider
|
||||
|
||||
if config.restore_from is None:
|
||||
raise ValueError("CheckpointConfig.restore_from must be set")
|
||||
location = str(config.restore_from)
|
||||
provider = detect_provider(location)
|
||||
raw = provider.from_checkpoint(location)
|
||||
return cls.model_validate_json(raw, **kwargs)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
state._provider = provider
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
async def afrom_checkpoint(
|
||||
cls, location: str, provider: BaseProvider, **kwargs: Any
|
||||
cls, config: CheckpointConfig, **kwargs: Any
|
||||
) -> RuntimeState:
|
||||
"""Async version of :meth:`from_checkpoint`.
|
||||
|
||||
Args:
|
||||
location: The identifier returned by a previous ``acheckpoint`` call.
|
||||
provider: The storage backend to read from.
|
||||
config: Checkpoint configuration with ``restore_from`` set.
|
||||
**kwargs: Passed to ``model_validate_json``.
|
||||
|
||||
Returns:
|
||||
A restored RuntimeState.
|
||||
"""
|
||||
from crewai.state.provider.utils import detect_provider
|
||||
|
||||
if config.restore_from is None:
|
||||
raise ValueError("CheckpointConfig.restore_from must be set")
|
||||
location = str(config.restore_from)
|
||||
provider = detect_provider(location)
|
||||
raw = await provider.afrom_checkpoint(location)
|
||||
return cls.model_validate_json(raw, **kwargs)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
state._provider = provider
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
return state
|
||||
|
||||
|
||||
def _prepare_entities(root: list[Entity]) -> None:
|
||||
|
||||
@@ -45,6 +45,7 @@ from crewai.events.types.task_events import (
|
||||
TaskStartedEvent,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
@@ -80,7 +81,7 @@ from crewai.utilities.guardrail_types import (
|
||||
GuardrailType,
|
||||
GuardrailsType,
|
||||
)
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.printer import PRINTER
|
||||
from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
@@ -115,7 +116,6 @@ class Task(BaseModel):
|
||||
used_tools: int = 0
|
||||
tools_errors: int = 0
|
||||
delegations: int = 0
|
||||
i18n: I18N = Field(default_factory=get_i18n)
|
||||
name: str | None = Field(default=None)
|
||||
prompt_context: str | None = None
|
||||
description: str = Field(description="Description of the actual task.")
|
||||
@@ -231,6 +231,8 @@ class Task(BaseModel):
|
||||
_original_description: str | None = PrivateAttr(default=None)
|
||||
_original_expected_output: str | None = PrivateAttr(default=None)
|
||||
_original_output_file: str | None = PrivateAttr(default=None)
|
||||
checkpoint_original_description: str | None = Field(default=None, exclude=False)
|
||||
checkpoint_original_expected_output: str | None = Field(default=None, exclude=False)
|
||||
_thread: threading.Thread | None = PrivateAttr(default=None)
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@@ -300,12 +302,14 @@ class Task(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_required_fields(self) -> Self:
|
||||
required_fields = ["description", "expected_output"]
|
||||
for field in required_fields:
|
||||
if getattr(self, field) is None:
|
||||
raise ValueError(
|
||||
f"{field} must be provided either directly or through config"
|
||||
)
|
||||
if self.description is None:
|
||||
raise ValueError(
|
||||
"description must be provided either directly or through config"
|
||||
)
|
||||
if self.expected_output is None:
|
||||
raise ValueError(
|
||||
"expected_output must be provided either directly or through config"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -837,8 +841,8 @@ class Task(BaseModel):
|
||||
should_inject = self.allow_crewai_trigger_context
|
||||
|
||||
if should_inject and self.agent:
|
||||
crew = getattr(self.agent, "crew", None)
|
||||
if crew and hasattr(crew, "_inputs") and crew._inputs:
|
||||
crew = self.agent.crew
|
||||
if crew and not isinstance(crew, str) and crew._inputs:
|
||||
trigger_payload = crew._inputs.get("crewai_trigger_payload")
|
||||
if trigger_payload is not None:
|
||||
description += f"\n\nTrigger Payload: {trigger_payload}"
|
||||
@@ -851,11 +855,12 @@ class Task(BaseModel):
|
||||
isinstance(self.agent.llm, BaseLLM)
|
||||
and self.agent.llm.supports_multimodal()
|
||||
):
|
||||
provider: str = str(
|
||||
getattr(self.agent.llm, "provider", None)
|
||||
or getattr(self.agent.llm, "model", "openai")
|
||||
provider: str = self.agent.llm.provider or self.agent.llm.model
|
||||
api: str | None = (
|
||||
self.agent.llm.api
|
||||
if isinstance(self.agent.llm, OpenAICompletion)
|
||||
else None
|
||||
)
|
||||
api: str | None = getattr(self.agent.llm, "api", None)
|
||||
supported_types = get_supported_content_types(provider, api)
|
||||
|
||||
def is_auto_injected(content_type: str) -> bool:
|
||||
@@ -896,7 +901,7 @@ class Task(BaseModel):
|
||||
|
||||
tasks_slices = [description]
|
||||
|
||||
output = self.i18n.slice("expected_output").format(
|
||||
output = I18N_DEFAULT.slice("expected_output").format(
|
||||
expected_output=self.expected_output
|
||||
)
|
||||
tasks_slices = [description, output]
|
||||
@@ -968,7 +973,7 @@ Follow these guidelines:
|
||||
raise ValueError(f"Error interpolating output_file path: {e!s}") from e
|
||||
|
||||
if inputs.get("crew_chat_messages"):
|
||||
conversation_instruction = self.i18n.slice(
|
||||
conversation_instruction = I18N_DEFAULT.slice(
|
||||
"conversation_history_instruction"
|
||||
)
|
||||
|
||||
@@ -1219,7 +1224,7 @@ Follow these guidelines:
|
||||
self.retry_count += 1
|
||||
current_retry_count = self.retry_count
|
||||
|
||||
context = self.i18n.errors("validation_error").format(
|
||||
context = I18N_DEFAULT.errors("validation_error").format(
|
||||
guardrail_result_error=guardrail_result.error,
|
||||
task_output=task_output.raw,
|
||||
)
|
||||
@@ -1316,7 +1321,7 @@ Follow these guidelines:
|
||||
self.retry_count += 1
|
||||
current_retry_count = self.retry_count
|
||||
|
||||
context = self.i18n.errors("validation_error").format(
|
||||
context = I18N_DEFAULT.errors("validation_error").format(
|
||||
guardrail_result_error=guardrail_result.error,
|
||||
task_output=task_output.raw,
|
||||
)
|
||||
|
||||
@@ -53,6 +53,7 @@ from crewai.telemetry.utils import (
|
||||
close_span,
|
||||
crew_memory_span_attribute_value,
|
||||
)
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.logger_utils import suppress_warnings
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
@@ -319,7 +320,7 @@ class Telemetry:
|
||||
"verbose?": agent.verbose,
|
||||
"max_iter": agent.max_iter,
|
||||
"max_rpm": agent.max_rpm,
|
||||
"i18n": agent.i18n.prompt_file,
|
||||
"i18n": I18N_DEFAULT.prompt_file,
|
||||
"function_calling_llm": (
|
||||
getattr(
|
||||
getattr(agent, "function_calling_llm", None),
|
||||
@@ -849,7 +850,7 @@ class Telemetry:
|
||||
"verbose?": agent.verbose,
|
||||
"max_iter": agent.max_iter,
|
||||
"max_rpm": agent.max_rpm,
|
||||
"i18n": agent.i18n.prompt_file,
|
||||
"i18n": I18N_DEFAULT.prompt_file,
|
||||
"llm": agent.llm.model
|
||||
if isinstance(agent.llm, BaseLLM)
|
||||
else str(agent.llm),
|
||||
@@ -1062,3 +1063,20 @@ class Telemetry:
|
||||
close_span(span)
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
|
||||
def template_installed_span(self, template_name: str) -> None:
|
||||
"""Records when a template is downloaded and installed.
|
||||
|
||||
Args:
|
||||
template_name: Name of the template that was installed
|
||||
(without the template_ prefix).
|
||||
"""
|
||||
|
||||
def _operation() -> None:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Template Installed")
|
||||
self._add_attribute(span, "crewai_version", version("crewai"))
|
||||
self._add_attribute(span, "template_name", template_name)
|
||||
close_span(span)
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
|
||||
@@ -3,10 +3,7 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities import I18N
|
||||
|
||||
|
||||
i18n = I18N()
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
|
||||
class AddImageToolSchema(BaseModel):
|
||||
@@ -19,9 +16,9 @@ class AddImageToolSchema(BaseModel):
|
||||
class AddImageTool(BaseTool):
|
||||
"""Tool for adding images to the content"""
|
||||
|
||||
name: str = Field(default_factory=lambda: i18n.tools("add_image")["name"]) # type: ignore[index]
|
||||
name: str = Field(default_factory=lambda: I18N_DEFAULT.tools("add_image")["name"]) # type: ignore[index]
|
||||
description: str = Field(
|
||||
default_factory=lambda: i18n.tools("add_image")["description"] # type: ignore[index]
|
||||
default_factory=lambda: I18N_DEFAULT.tools("add_image")["description"] # type: ignore[index]
|
||||
)
|
||||
args_schema: type[BaseModel] = AddImageToolSchema
|
||||
|
||||
@@ -31,7 +28,7 @@ class AddImageTool(BaseTool):
|
||||
action: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
action = action or i18n.tools("add_image")["default_action"] # type: ignore
|
||||
action = action or I18N_DEFAULT.tools("add_image")["default_action"] # type: ignore
|
||||
content = [
|
||||
{"type": "text", "text": action},
|
||||
{
|
||||
|
||||
@@ -5,21 +5,19 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.tools.agent_tools.ask_question_tool import AskQuestionTool
|
||||
from crewai.tools.agent_tools.delegate_work_tool import DelegateWorkTool
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.i18n import I18N
|
||||
|
||||
|
||||
class AgentTools:
|
||||
"""Manager class for agent-related tools"""
|
||||
|
||||
def __init__(self, agents: Sequence[BaseAgent], i18n: I18N | None = None) -> None:
|
||||
def __init__(self, agents: Sequence[BaseAgent]) -> None:
|
||||
self.agents = agents
|
||||
self.i18n = i18n if i18n is not None else get_i18n()
|
||||
|
||||
def tools(self) -> list[BaseTool]:
|
||||
"""Get all available agent tools"""
|
||||
@@ -27,14 +25,12 @@ class AgentTools:
|
||||
|
||||
delegate_tool = DelegateWorkTool(
|
||||
agents=self.agents,
|
||||
i18n=self.i18n,
|
||||
description=self.i18n.tools("delegate_work").format(coworkers=coworkers), # type: ignore
|
||||
description=I18N_DEFAULT.tools("delegate_work").format(coworkers=coworkers), # type: ignore
|
||||
)
|
||||
|
||||
ask_tool = AskQuestionTool(
|
||||
agents=self.agents,
|
||||
i18n=self.i18n,
|
||||
description=self.i18n.tools("ask_question").format(coworkers=coworkers), # type: ignore
|
||||
description=I18N_DEFAULT.tools("ask_question").format(coworkers=coworkers), # type: ignore
|
||||
)
|
||||
|
||||
return [delegate_tool, ask_tool]
|
||||
|
||||
@@ -6,7 +6,7 @@ from pydantic import Field
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -16,9 +16,6 @@ class BaseAgentTool(BaseTool):
|
||||
"""Base class for agent-related tools"""
|
||||
|
||||
agents: list[BaseAgent] = Field(description="List of available agents")
|
||||
i18n: I18N = Field(
|
||||
default_factory=get_i18n, description="Internationalization settings"
|
||||
)
|
||||
|
||||
def sanitize_agent_name(self, name: str) -> str:
|
||||
"""
|
||||
@@ -93,7 +90,7 @@ class BaseAgentTool(BaseTool):
|
||||
)
|
||||
except (AttributeError, ValueError) as e:
|
||||
# Handle specific exceptions that might occur during role name processing
|
||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||
return I18N_DEFAULT.errors("agent_tool_unexisting_coworker").format(
|
||||
coworkers="\n".join(
|
||||
[
|
||||
f"- {self.sanitize_agent_name(agent.role)}"
|
||||
@@ -105,7 +102,7 @@ class BaseAgentTool(BaseTool):
|
||||
|
||||
if not agent:
|
||||
# No matching agent found after sanitization
|
||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||
return I18N_DEFAULT.errors("agent_tool_unexisting_coworker").format(
|
||||
coworkers="\n".join(
|
||||
[
|
||||
f"- {self.sanitize_agent_name(agent.role)}"
|
||||
@@ -120,8 +117,7 @@ class BaseAgentTool(BaseTool):
|
||||
task_with_assigned_agent = Task(
|
||||
description=task,
|
||||
agent=selected_agent,
|
||||
expected_output=selected_agent.i18n.slice("manager_request"),
|
||||
i18n=selected_agent.i18n,
|
||||
expected_output=I18N_DEFAULT.slice("manager_request"),
|
||||
)
|
||||
logger.debug(
|
||||
f"Created task for agent '{self.sanitize_agent_name(selected_agent.role)}': {task}"
|
||||
@@ -129,6 +125,6 @@ class BaseAgentTool(BaseTool):
|
||||
return selected_agent.execute_task(task_with_assigned_agent, context)
|
||||
except Exception as e:
|
||||
# Handle task creation or execution errors
|
||||
return self.i18n.errors("agent_tool_execution_error").format(
|
||||
return I18N_DEFAULT.errors("agent_tool_execution_error").format(
|
||||
agent_role=self.sanitize_agent_name(selected_agent.role), error=str(e)
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
|
||||
class RecallMemorySchema(BaseModel):
|
||||
@@ -114,18 +114,17 @@ def create_memory_tools(memory: Any) -> list[BaseTool]:
|
||||
Returns:
|
||||
List containing a RecallMemoryTool and, if not read-only, a RememberTool.
|
||||
"""
|
||||
i18n = get_i18n()
|
||||
tools: list[BaseTool] = [
|
||||
RecallMemoryTool(
|
||||
memory=memory,
|
||||
description=i18n.tools("recall_memory"),
|
||||
description=I18N_DEFAULT.tools("recall_memory"),
|
||||
),
|
||||
]
|
||||
if not memory.read_only:
|
||||
tools.append(
|
||||
RememberTool(
|
||||
memory=memory,
|
||||
description=i18n.tools("save_to_memory"),
|
||||
description=I18N_DEFAULT.tools("save_to_memory"),
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
||||
@@ -28,7 +28,7 @@ from crewai.utilities.agent_utils import (
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.printer import PRINTER
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
@@ -93,7 +93,6 @@ class ToolUsage:
|
||||
action: Any = None,
|
||||
fingerprint_context: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self._i18n: I18N = agent.i18n if agent else get_i18n()
|
||||
self._telemetry: Telemetry = Telemetry()
|
||||
self._run_attempts: int = 1
|
||||
self._max_parsing_attempts: int = 3
|
||||
@@ -146,7 +145,7 @@ class ToolUsage:
|
||||
if (
|
||||
isinstance(tool, CrewStructuredTool)
|
||||
and sanitize_tool_name(tool.name)
|
||||
== sanitize_tool_name(self._i18n.tools("add_image")["name"]) # type: ignore
|
||||
== sanitize_tool_name(I18N_DEFAULT.tools("add_image")["name"]) # type: ignore
|
||||
):
|
||||
try:
|
||||
return self._use(tool_string=tool_string, tool=tool, calling=calling)
|
||||
@@ -194,7 +193,7 @@ class ToolUsage:
|
||||
if (
|
||||
isinstance(tool, CrewStructuredTool)
|
||||
and sanitize_tool_name(tool.name)
|
||||
== sanitize_tool_name(self._i18n.tools("add_image")["name"]) # type: ignore
|
||||
== sanitize_tool_name(I18N_DEFAULT.tools("add_image")["name"]) # type: ignore
|
||||
):
|
||||
try:
|
||||
return await self._ause(
|
||||
@@ -230,7 +229,7 @@ class ToolUsage:
|
||||
"""
|
||||
if self._check_tool_repeated_usage(calling=calling):
|
||||
try:
|
||||
result = self._i18n.errors("task_repeated_usage").format(
|
||||
result = I18N_DEFAULT.errors("task_repeated_usage").format(
|
||||
tool_names=self.tools_names
|
||||
)
|
||||
self._telemetry.tool_repeated_usage(
|
||||
@@ -415,7 +414,7 @@ class ToolUsage:
|
||||
self._run_attempts += 1
|
||||
if self._run_attempts > self._max_parsing_attempts:
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
error_message = self._i18n.errors(
|
||||
error_message = I18N_DEFAULT.errors(
|
||||
"tool_usage_exception"
|
||||
).format(
|
||||
error=e,
|
||||
@@ -423,7 +422,7 @@ class ToolUsage:
|
||||
tool_inputs=tool.description,
|
||||
)
|
||||
result = ToolUsageError(
|
||||
f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
|
||||
f"\n{error_message}.\nMoving on then. {I18N_DEFAULT.slice('format').format(tool_names=self.tools_names)}"
|
||||
).message
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
@@ -461,7 +460,7 @@ class ToolUsage:
|
||||
# Repeated usage check happens before event emission - safe to return early
|
||||
if self._check_tool_repeated_usage(calling=calling):
|
||||
try:
|
||||
result = self._i18n.errors("task_repeated_usage").format(
|
||||
result = I18N_DEFAULT.errors("task_repeated_usage").format(
|
||||
tool_names=self.tools_names
|
||||
)
|
||||
self._telemetry.tool_repeated_usage(
|
||||
@@ -648,7 +647,7 @@ class ToolUsage:
|
||||
self._run_attempts += 1
|
||||
if self._run_attempts > self._max_parsing_attempts:
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
error_message = self._i18n.errors(
|
||||
error_message = I18N_DEFAULT.errors(
|
||||
"tool_usage_exception"
|
||||
).format(
|
||||
error=e,
|
||||
@@ -656,7 +655,7 @@ class ToolUsage:
|
||||
tool_inputs=tool.description,
|
||||
)
|
||||
result = ToolUsageError(
|
||||
f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
|
||||
f"\n{error_message}.\nMoving on then. {I18N_DEFAULT.slice('format').format(tool_names=self.tools_names)}"
|
||||
).message
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
@@ -699,7 +698,7 @@ class ToolUsage:
|
||||
|
||||
def _remember_format(self, result: str) -> str:
|
||||
result = str(result)
|
||||
result += "\n\n" + self._i18n.slice("tools").format(
|
||||
result += "\n\n" + I18N_DEFAULT.slice("tools").format(
|
||||
tools=self.tools_description, tool_names=self.tools_names
|
||||
)
|
||||
return result
|
||||
@@ -825,12 +824,12 @@ class ToolUsage:
|
||||
except Exception:
|
||||
if raise_error:
|
||||
raise
|
||||
return ToolUsageError(f"{self._i18n.errors('tool_arguments_error')}")
|
||||
return ToolUsageError(f"{I18N_DEFAULT.errors('tool_arguments_error')}")
|
||||
|
||||
if not isinstance(arguments, dict):
|
||||
if raise_error:
|
||||
raise
|
||||
return ToolUsageError(f"{self._i18n.errors('tool_arguments_error')}")
|
||||
return ToolUsageError(f"{I18N_DEFAULT.errors('tool_arguments_error')}")
|
||||
|
||||
return ToolCalling(
|
||||
tool_name=sanitize_tool_name(tool.name),
|
||||
@@ -856,7 +855,7 @@ class ToolUsage:
|
||||
if self.agent and self.agent.verbose:
|
||||
PRINTER.print(content=f"\n\n{e}\n", color="red")
|
||||
return ToolUsageError(
|
||||
f"{self._i18n.errors('tool_usage_error').format(error=e)}\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
|
||||
f"{I18N_DEFAULT.errors('tool_usage_error').format(error=e)}\nMoving on then. {I18N_DEFAULT.slice('format').format(tool_names=self.tools_names)}"
|
||||
)
|
||||
return self._tool_calling(tool_string)
|
||||
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from collections.abc import AsyncIterator, Callable, Iterator
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -78,12 +79,21 @@ class StreamingOutputBase(Generic[T]):
|
||||
via the .result property after streaming completes.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
sync_iterator: Iterator[StreamChunk] | None = None,
|
||||
async_iterator: AsyncIterator[StreamChunk] | None = None,
|
||||
) -> None:
|
||||
"""Initialize streaming output base."""
|
||||
self._result: T | None = None
|
||||
self._completed: bool = False
|
||||
self._chunks: list[StreamChunk] = []
|
||||
self._error: Exception | None = None
|
||||
self._cancelled: bool = False
|
||||
self._exhausted: bool = False
|
||||
self._on_cleanup: Callable[[], None] | None = None
|
||||
self._sync_iterator = sync_iterator
|
||||
self._async_iterator = async_iterator
|
||||
|
||||
@property
|
||||
def result(self) -> T:
|
||||
@@ -112,6 +122,11 @@ class StreamingOutputBase(Generic[T]):
|
||||
"""Check if streaming has completed."""
|
||||
return self._completed
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
"""Check if streaming was cancelled."""
|
||||
return self._cancelled
|
||||
|
||||
@property
|
||||
def chunks(self) -> list[StreamChunk]:
|
||||
"""Get all collected chunks so far."""
|
||||
@@ -129,6 +144,98 @@ class StreamingOutputBase(Generic[T]):
|
||||
if chunk.chunk_type == StreamChunkType.TEXT
|
||||
)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""Enter async context manager."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_info: Any) -> None:
|
||||
"""Exit async context manager, cancelling if still running."""
|
||||
await self.aclose()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Cancel streaming and clean up resources.
|
||||
|
||||
Cancels any in-flight tasks and closes the underlying async iterator.
|
||||
Safe to call multiple times. No-op if already cancelled or fully consumed.
|
||||
"""
|
||||
if self._cancelled or self._exhausted or self._error is not None:
|
||||
return
|
||||
self._cancelled = True
|
||||
self._completed = True
|
||||
if self._async_iterator is not None and hasattr(self._async_iterator, "aclose"):
|
||||
await self._async_iterator.aclose()
|
||||
if self._on_cleanup is not None:
|
||||
self._on_cleanup()
|
||||
self._on_cleanup = None
|
||||
|
||||
def close(self) -> None:
|
||||
"""Cancel streaming and clean up resources (sync).
|
||||
|
||||
Closes the underlying sync iterator. Safe to call multiple times.
|
||||
No-op if already cancelled, fully consumed, or errored.
|
||||
"""
|
||||
if self._cancelled or self._exhausted or self._error is not None:
|
||||
return
|
||||
self._cancelled = True
|
||||
self._completed = True
|
||||
if self._sync_iterator is not None and hasattr(self._sync_iterator, "close"):
|
||||
self._sync_iterator.close()
|
||||
if self._on_cleanup is not None:
|
||||
self._on_cleanup()
|
||||
self._on_cleanup = None
|
||||
|
||||
def __iter__(self) -> Iterator[StreamChunk]:
|
||||
"""Iterate over stream chunks synchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sync iterator not available.
|
||||
"""
|
||||
if self._sync_iterator is None:
|
||||
raise RuntimeError("Sync iterator not available")
|
||||
try:
|
||||
for chunk in self._sync_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
self._exhausted = True
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Return async iterator for stream chunks.
|
||||
|
||||
Returns:
|
||||
Async iterator for StreamChunk objects.
|
||||
"""
|
||||
return self._async_iterate()
|
||||
|
||||
async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Iterate over stream chunks asynchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If async iterator not available.
|
||||
"""
|
||||
if self._async_iterator is None:
|
||||
raise RuntimeError("Async iterator not available")
|
||||
try:
|
||||
async for chunk in self._async_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
self._exhausted = True
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
|
||||
class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
|
||||
"""Streaming output wrapper for crew execution.
|
||||
@@ -167,9 +274,7 @@ class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
|
||||
sync_iterator: Synchronous iterator for chunks.
|
||||
async_iterator: Asynchronous iterator for chunks.
|
||||
"""
|
||||
super().__init__()
|
||||
self._sync_iterator = sync_iterator
|
||||
self._async_iterator = async_iterator
|
||||
super().__init__(sync_iterator=sync_iterator, async_iterator=async_iterator)
|
||||
self._results: list[CrewOutput] | None = None
|
||||
|
||||
@property
|
||||
@@ -204,56 +309,6 @@ class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
|
||||
self._results = results
|
||||
self._completed = True
|
||||
|
||||
def __iter__(self) -> Iterator[StreamChunk]:
|
||||
"""Iterate over stream chunks synchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sync iterator not available.
|
||||
"""
|
||||
if self._sync_iterator is None:
|
||||
raise RuntimeError("Sync iterator not available")
|
||||
try:
|
||||
for chunk in self._sync_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Return async iterator for stream chunks.
|
||||
|
||||
Returns:
|
||||
Async iterator for StreamChunk objects.
|
||||
"""
|
||||
return self._async_iterate()
|
||||
|
||||
async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Iterate over stream chunks asynchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If async iterator not available.
|
||||
"""
|
||||
if self._async_iterator is None:
|
||||
raise RuntimeError("Async iterator not available")
|
||||
try:
|
||||
async for chunk in self._async_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def _set_result(self, result: CrewOutput) -> None:
|
||||
"""Set the final result after streaming completes.
|
||||
|
||||
@@ -286,71 +341,6 @@ class FlowStreamingOutput(StreamingOutputBase[Any]):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sync_iterator: Iterator[StreamChunk] | None = None,
|
||||
async_iterator: AsyncIterator[StreamChunk] | None = None,
|
||||
) -> None:
|
||||
"""Initialize flow streaming output.
|
||||
|
||||
Args:
|
||||
sync_iterator: Synchronous iterator for chunks.
|
||||
async_iterator: Asynchronous iterator for chunks.
|
||||
"""
|
||||
super().__init__()
|
||||
self._sync_iterator = sync_iterator
|
||||
self._async_iterator = async_iterator
|
||||
|
||||
def __iter__(self) -> Iterator[StreamChunk]:
|
||||
"""Iterate over stream chunks synchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sync iterator not available.
|
||||
"""
|
||||
if self._sync_iterator is None:
|
||||
raise RuntimeError("Sync iterator not available")
|
||||
try:
|
||||
for chunk in self._sync_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Return async iterator for stream chunks.
|
||||
|
||||
Returns:
|
||||
Async iterator for StreamChunk objects.
|
||||
"""
|
||||
return self._async_iterate()
|
||||
|
||||
async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Iterate over stream chunks asynchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If async iterator not available.
|
||||
"""
|
||||
if self._async_iterator is None:
|
||||
raise RuntimeError("Async iterator not available")
|
||||
try:
|
||||
async for chunk in self._async_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def _set_result(self, result: Any) -> None:
|
||||
"""Set the final result after streaming completes.
|
||||
|
||||
|
||||
@@ -29,6 +29,14 @@ class UsageMetrics(BaseModel):
|
||||
completion_tokens: int = Field(
|
||||
default=0, description="Number of tokens used in completions."
|
||||
)
|
||||
reasoning_tokens: int = Field(
|
||||
default=0,
|
||||
description="Number of reasoning/thinking tokens (e.g. OpenAI o-series, Gemini thinking).",
|
||||
)
|
||||
cache_creation_tokens: int = Field(
|
||||
default=0,
|
||||
description="Number of cache creation tokens (e.g. Anthropic cache writes).",
|
||||
)
|
||||
successful_requests: int = Field(
|
||||
default=0, description="Number of successful requests made."
|
||||
)
|
||||
@@ -43,4 +51,6 @@ class UsageMetrics(BaseModel):
|
||||
self.prompt_tokens += usage_metrics.prompt_tokens
|
||||
self.cached_prompt_tokens += usage_metrics.cached_prompt_tokens
|
||||
self.completion_tokens += usage_metrics.completion_tokens
|
||||
self.reasoning_tokens += usage_metrics.reasoning_tokens
|
||||
self.cache_creation_tokens += usage_metrics.cache_creation_tokens
|
||||
self.successful_requests += usage_metrics.successful_requests
|
||||
|
||||
@@ -31,7 +31,7 @@ from crewai.utilities.errors import AgentRepositoryError
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.printer import PRINTER, ColoredText, Printer
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
@@ -254,7 +254,6 @@ def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool:
|
||||
def handle_max_iterations_exceeded(
|
||||
formatted_answer: AgentAction | AgentFinish | None,
|
||||
printer: Printer,
|
||||
i18n: I18N,
|
||||
messages: list[LLMMessage],
|
||||
llm: LLM | BaseLLM,
|
||||
callbacks: list[TokenCalcHandler],
|
||||
@@ -265,7 +264,6 @@ def handle_max_iterations_exceeded(
|
||||
Args:
|
||||
formatted_answer: The last formatted answer from the agent.
|
||||
printer: Printer instance for output.
|
||||
i18n: I18N instance for internationalization.
|
||||
messages: List of messages to send to the LLM.
|
||||
llm: The LLM instance to call.
|
||||
callbacks: List of callbacks for the LLM call.
|
||||
@@ -282,10 +280,10 @@ def handle_max_iterations_exceeded(
|
||||
|
||||
if formatted_answer and hasattr(formatted_answer, "text"):
|
||||
assistant_message = (
|
||||
formatted_answer.text + f"\n{i18n.errors('force_final_answer')}"
|
||||
formatted_answer.text + f"\n{I18N_DEFAULT.errors('force_final_answer')}"
|
||||
)
|
||||
else:
|
||||
assistant_message = i18n.errors("force_final_answer")
|
||||
assistant_message = I18N_DEFAULT.errors("force_final_answer")
|
||||
|
||||
messages.append(format_message_for_llm(assistant_message, role="assistant"))
|
||||
|
||||
@@ -687,7 +685,6 @@ def handle_context_length(
|
||||
messages: list[LLMMessage],
|
||||
llm: LLM | BaseLLM,
|
||||
callbacks: list[TokenCalcHandler],
|
||||
i18n: I18N,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
"""Handle context length exceeded by either summarizing or raising an error.
|
||||
@@ -698,7 +695,6 @@ def handle_context_length(
|
||||
messages: List of messages to summarize
|
||||
llm: LLM instance for summarization
|
||||
callbacks: List of callbacks for LLM
|
||||
i18n: I18N instance for messages
|
||||
|
||||
Raises:
|
||||
SystemExit: If context length is exceeded and user opts not to summarize
|
||||
@@ -710,7 +706,7 @@ def handle_context_length(
|
||||
color="yellow",
|
||||
)
|
||||
summarize_messages(
|
||||
messages=messages, llm=llm, callbacks=callbacks, i18n=i18n, verbose=verbose
|
||||
messages=messages, llm=llm, callbacks=callbacks, verbose=verbose
|
||||
)
|
||||
else:
|
||||
if verbose:
|
||||
@@ -863,7 +859,6 @@ async def _asummarize_chunks(
|
||||
chunks: list[list[LLMMessage]],
|
||||
llm: LLM | BaseLLM,
|
||||
callbacks: list[TokenCalcHandler],
|
||||
i18n: I18N,
|
||||
) -> list[SummaryContent]:
|
||||
"""Summarize multiple message chunks concurrently using asyncio.
|
||||
|
||||
@@ -871,7 +866,6 @@ async def _asummarize_chunks(
|
||||
chunks: List of message chunks to summarize.
|
||||
llm: LLM instance (must support ``acall``).
|
||||
callbacks: List of callbacks for the LLM.
|
||||
i18n: I18N instance for prompt templates.
|
||||
|
||||
Returns:
|
||||
Ordered list of summary contents, one per chunk.
|
||||
@@ -881,10 +875,10 @@ async def _asummarize_chunks(
|
||||
conversation_text = _format_messages_for_summary(chunk)
|
||||
summarization_messages = [
|
||||
format_message_for_llm(
|
||||
i18n.slice("summarizer_system_message"), role="system"
|
||||
I18N_DEFAULT.slice("summarizer_system_message"), role="system"
|
||||
),
|
||||
format_message_for_llm(
|
||||
i18n.slice("summarize_instruction").format(
|
||||
I18N_DEFAULT.slice("summarize_instruction").format(
|
||||
conversation=conversation_text
|
||||
),
|
||||
),
|
||||
@@ -901,7 +895,6 @@ def summarize_messages(
|
||||
messages: list[LLMMessage],
|
||||
llm: LLM | BaseLLM,
|
||||
callbacks: list[TokenCalcHandler],
|
||||
i18n: I18N,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
"""Summarize messages to fit within context window.
|
||||
@@ -917,7 +910,6 @@ def summarize_messages(
|
||||
messages: List of messages to summarize (modified in-place)
|
||||
llm: LLM instance for summarization
|
||||
callbacks: List of callbacks for LLM
|
||||
i18n: I18N instance for messages
|
||||
verbose: Whether to print progress.
|
||||
"""
|
||||
# 1. Extract & preserve file attachments from user messages
|
||||
@@ -953,10 +945,10 @@ def summarize_messages(
|
||||
conversation_text = _format_messages_for_summary(chunk)
|
||||
summarization_messages = [
|
||||
format_message_for_llm(
|
||||
i18n.slice("summarizer_system_message"), role="system"
|
||||
I18N_DEFAULT.slice("summarizer_system_message"), role="system"
|
||||
),
|
||||
format_message_for_llm(
|
||||
i18n.slice("summarize_instruction").format(
|
||||
I18N_DEFAULT.slice("summarize_instruction").format(
|
||||
conversation=conversation_text
|
||||
),
|
||||
),
|
||||
@@ -971,9 +963,7 @@ def summarize_messages(
|
||||
content=f"Summarizing {total_chunks} chunks in parallel...",
|
||||
color="yellow",
|
||||
)
|
||||
coro = _asummarize_chunks(
|
||||
chunks=chunks, llm=llm, callbacks=callbacks, i18n=i18n
|
||||
)
|
||||
coro = _asummarize_chunks(chunks=chunks, llm=llm, callbacks=callbacks)
|
||||
if is_inside_event_loop():
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
@@ -988,7 +978,7 @@ def summarize_messages(
|
||||
messages.extend(system_messages)
|
||||
|
||||
summary_message = format_message_for_llm(
|
||||
i18n.slice("summary").format(merged_summary=merged_summary)
|
||||
I18N_DEFAULT.slice("summary").format(merged_summary=merged_summary)
|
||||
)
|
||||
if preserved_files:
|
||||
summary_message["files"] = preserved_files
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_output_converter import OutputConverter
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
from crewai.utilities.printer import PRINTER
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
_JSON_PATTERN: Final[re.Pattern[str]] = re.compile(r"({.*})", re.DOTALL)
|
||||
_I18N = get_i18n()
|
||||
_I18N = I18N_DEFAULT
|
||||
|
||||
|
||||
class ConverterError(Exception):
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.task_events import TaskEvaluationEvent
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.training_converter import TrainingConverter
|
||||
|
||||
@@ -98,11 +98,9 @@ class TaskEvaluator:
|
||||
|
||||
if not self.llm.supports_function_calling(): # type: ignore[union-attr]
|
||||
schema_dict = generate_model_description(TaskEvaluation)
|
||||
output_schema: str = (
|
||||
get_i18n()
|
||||
.slice("formatted_task_instructions")
|
||||
.format(output_format=json.dumps(schema_dict, indent=2))
|
||||
)
|
||||
output_schema: str = I18N_DEFAULT.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=json.dumps(schema_dict, indent=2))
|
||||
instructions = f"{instructions}\n\n{output_schema}"
|
||||
|
||||
converter = Converter(
|
||||
@@ -174,11 +172,9 @@ class TaskEvaluator:
|
||||
|
||||
if not self.llm.supports_function_calling(): # type: ignore[union-attr]
|
||||
schema_dict = generate_model_description(TrainingTaskEvaluation)
|
||||
output_schema: str = (
|
||||
get_i18n()
|
||||
.slice("formatted_task_instructions")
|
||||
.format(output_format=json.dumps(schema_dict, indent=2))
|
||||
)
|
||||
output_schema: str = I18N_DEFAULT.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=json.dumps(schema_dict, indent=2))
|
||||
instructions = f"{instructions}\n\n{output_schema}"
|
||||
|
||||
converter = TrainingConverter(
|
||||
|
||||
@@ -142,3 +142,6 @@ def get_i18n(prompt_file: str | None = None) -> I18N:
|
||||
Cached I18N instance.
|
||||
"""
|
||||
return I18N(prompt_file=prompt_file)
|
||||
|
||||
|
||||
I18N_DEFAULT: I18N = get_i18n()
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
|
||||
class StandardPromptResult(BaseModel):
|
||||
@@ -49,7 +49,6 @@ class Prompts(BaseModel):
|
||||
- Need to refactor so that prompt is not tightly coupled to agent.
|
||||
"""
|
||||
|
||||
i18n: I18N = Field(default_factory=get_i18n)
|
||||
has_tools: bool = Field(
|
||||
default=False, description="Indicates if the agent has access to tools"
|
||||
)
|
||||
@@ -140,13 +139,13 @@ class Prompts(BaseModel):
|
||||
if not system_template or not prompt_template:
|
||||
# If any of the required templates are missing, fall back to the default format
|
||||
prompt_parts: list[str] = [
|
||||
self.i18n.slice(component) for component in components
|
||||
I18N_DEFAULT.slice(component) for component in components
|
||||
]
|
||||
prompt = "".join(prompt_parts)
|
||||
else:
|
||||
# All templates are provided, use them
|
||||
template_parts: list[str] = [
|
||||
self.i18n.slice(component)
|
||||
I18N_DEFAULT.slice(component)
|
||||
for component in components
|
||||
if component != "task"
|
||||
]
|
||||
@@ -154,7 +153,7 @@ class Prompts(BaseModel):
|
||||
"{{ .System }}", "".join(template_parts)
|
||||
)
|
||||
prompt = prompt_template.replace(
|
||||
"{{ .Prompt }}", "".join(self.i18n.slice("task"))
|
||||
"{{ .Prompt }}", "".join(I18N_DEFAULT.slice("task"))
|
||||
)
|
||||
# Handle missing response_template
|
||||
if response_template:
|
||||
|
||||
@@ -19,7 +19,18 @@ from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
import datetime
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, TypedDict, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Final,
|
||||
ForwardRef,
|
||||
Literal,
|
||||
Optional,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
import uuid
|
||||
|
||||
import jsonref # type: ignore[import-untyped]
|
||||
@@ -99,15 +110,26 @@ def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
defs = schema.get("$defs", {})
|
||||
schema_copy = deepcopy(schema)
|
||||
expanding: set[str] = set()
|
||||
|
||||
def _resolve(node: Any) -> Any:
|
||||
if isinstance(node, dict):
|
||||
ref = node.get("$ref")
|
||||
if isinstance(ref, str) and ref.startswith("#/$defs/"):
|
||||
def_name = ref.replace("#/$defs/", "")
|
||||
if def_name in defs:
|
||||
if def_name not in defs:
|
||||
raise KeyError(f"Definition '{def_name}' not found in $defs.")
|
||||
if def_name in expanding:
|
||||
def_schema = defs[def_name]
|
||||
stub: dict[str, Any] = {"type": def_schema.get("type", "object")}
|
||||
if "description" in def_schema:
|
||||
stub["description"] = def_schema["description"]
|
||||
return stub
|
||||
expanding.add(def_name)
|
||||
try:
|
||||
return _resolve(deepcopy(defs[def_name]))
|
||||
raise KeyError(f"Definition '{def_name}' not found in $defs.")
|
||||
finally:
|
||||
expanding.discard(def_name)
|
||||
return {k: _resolve(v) for k, v in node.items()}
|
||||
|
||||
if isinstance(node, list):
|
||||
@@ -119,7 +141,11 @@ def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
def add_key_in_dict_recursively(
|
||||
d: dict[str, Any], key: str, value: Any, criteria: Callable[[dict[str, Any]], bool]
|
||||
d: dict[str, Any],
|
||||
key: str,
|
||||
value: Any,
|
||||
criteria: Callable[[dict[str, Any]], bool],
|
||||
_seen: set[int] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Recursively adds a key/value pair to all nested dicts matching `criteria`.
|
||||
|
||||
@@ -128,22 +154,31 @@ def add_key_in_dict_recursively(
|
||||
key: The key to add.
|
||||
value: The value to add.
|
||||
criteria: A function that returns True for dicts that should receive the key.
|
||||
_seen: Internal set of visited ``id()``s, used to guard cyclic schemas.
|
||||
|
||||
Returns:
|
||||
The modified dictionary.
|
||||
"""
|
||||
if _seen is None:
|
||||
_seen = set()
|
||||
if isinstance(d, dict):
|
||||
if id(d) in _seen:
|
||||
return d
|
||||
_seen.add(id(d))
|
||||
if criteria(d) and key not in d:
|
||||
d[key] = value
|
||||
for v in d.values():
|
||||
add_key_in_dict_recursively(v, key, value, criteria)
|
||||
add_key_in_dict_recursively(v, key, value, criteria, _seen)
|
||||
elif isinstance(d, list):
|
||||
if id(d) in _seen:
|
||||
return d
|
||||
_seen.add(id(d))
|
||||
for i in d:
|
||||
add_key_in_dict_recursively(i, key, value, criteria)
|
||||
add_key_in_dict_recursively(i, key, value, criteria, _seen)
|
||||
return d
|
||||
|
||||
|
||||
def force_additional_properties_false(d: Any) -> Any:
|
||||
def force_additional_properties_false(d: Any, _seen: set[int] | None = None) -> Any:
|
||||
"""Force additionalProperties=false on all object-type dicts recursively.
|
||||
|
||||
OpenAI strict mode requires all objects to have additionalProperties=false.
|
||||
@@ -154,11 +189,17 @@ def force_additional_properties_false(d: Any) -> Any:
|
||||
|
||||
Args:
|
||||
d: The dictionary/list to modify.
|
||||
_seen: Internal set of visited ``id()``s, used to guard cyclic schemas.
|
||||
|
||||
Returns:
|
||||
The modified dictionary/list.
|
||||
"""
|
||||
if _seen is None:
|
||||
_seen = set()
|
||||
if isinstance(d, dict):
|
||||
if id(d) in _seen:
|
||||
return d
|
||||
_seen.add(id(d))
|
||||
if d.get("type") == "object":
|
||||
d["additionalProperties"] = False
|
||||
if "properties" not in d:
|
||||
@@ -166,10 +207,13 @@ def force_additional_properties_false(d: Any) -> Any:
|
||||
if "required" not in d:
|
||||
d["required"] = []
|
||||
for v in d.values():
|
||||
force_additional_properties_false(v)
|
||||
force_additional_properties_false(v, _seen)
|
||||
elif isinstance(d, list):
|
||||
if id(d) in _seen:
|
||||
return d
|
||||
_seen.add(id(d))
|
||||
for i in d:
|
||||
force_additional_properties_false(i)
|
||||
force_additional_properties_false(i, _seen)
|
||||
return d
|
||||
|
||||
|
||||
@@ -183,7 +227,7 @@ OPENAI_SUPPORTED_FORMATS: Final[
|
||||
}
|
||||
|
||||
|
||||
def strip_unsupported_formats(d: Any) -> Any:
|
||||
def strip_unsupported_formats(d: Any, _seen: set[int] | None = None) -> Any:
|
||||
"""Remove format annotations that OpenAI strict mode doesn't support.
|
||||
|
||||
OpenAI only supports: date-time, date, time, duration.
|
||||
@@ -191,11 +235,17 @@ def strip_unsupported_formats(d: Any) -> Any:
|
||||
|
||||
Args:
|
||||
d: The dictionary/list to modify.
|
||||
_seen: Internal set of visited ``id()``s, used to guard cyclic schemas.
|
||||
|
||||
Returns:
|
||||
The modified dictionary/list.
|
||||
"""
|
||||
if _seen is None:
|
||||
_seen = set()
|
||||
if isinstance(d, dict):
|
||||
if id(d) in _seen:
|
||||
return d
|
||||
_seen.add(id(d))
|
||||
format_value = d.get("format")
|
||||
if (
|
||||
isinstance(format_value, str)
|
||||
@@ -203,14 +253,17 @@ def strip_unsupported_formats(d: Any) -> Any:
|
||||
):
|
||||
del d["format"]
|
||||
for v in d.values():
|
||||
strip_unsupported_formats(v)
|
||||
strip_unsupported_formats(v, _seen)
|
||||
elif isinstance(d, list):
|
||||
if id(d) in _seen:
|
||||
return d
|
||||
_seen.add(id(d))
|
||||
for i in d:
|
||||
strip_unsupported_formats(i)
|
||||
strip_unsupported_formats(i, _seen)
|
||||
return d
|
||||
|
||||
|
||||
def ensure_type_in_schemas(d: Any) -> Any:
|
||||
def ensure_type_in_schemas(d: Any, _seen: set[int] | None = None) -> Any:
|
||||
"""Ensure all schema objects in anyOf/oneOf have a 'type' key.
|
||||
|
||||
OpenAI strict mode requires every schema to have a 'type' key.
|
||||
@@ -218,11 +271,17 @@ def ensure_type_in_schemas(d: Any) -> Any:
|
||||
|
||||
Args:
|
||||
d: The dictionary/list to modify.
|
||||
_seen: Internal set of visited ``id()``s, used to guard cyclic schemas.
|
||||
|
||||
Returns:
|
||||
The modified dictionary/list.
|
||||
"""
|
||||
if _seen is None:
|
||||
_seen = set()
|
||||
if isinstance(d, dict):
|
||||
if id(d) in _seen:
|
||||
return d
|
||||
_seen.add(id(d))
|
||||
for key in ("anyOf", "oneOf"):
|
||||
if key in d:
|
||||
schema_list = d[key]
|
||||
@@ -230,12 +289,15 @@ def ensure_type_in_schemas(d: Any) -> Any:
|
||||
if isinstance(schema, dict) and schema == {}:
|
||||
schema_list[i] = {"type": "object"}
|
||||
else:
|
||||
ensure_type_in_schemas(schema)
|
||||
ensure_type_in_schemas(schema, _seen)
|
||||
for v in d.values():
|
||||
ensure_type_in_schemas(v)
|
||||
ensure_type_in_schemas(v, _seen)
|
||||
elif isinstance(d, list):
|
||||
if id(d) in _seen:
|
||||
return d
|
||||
_seen.add(id(d))
|
||||
for item in d:
|
||||
ensure_type_in_schemas(item)
|
||||
ensure_type_in_schemas(item, _seen)
|
||||
return d
|
||||
|
||||
|
||||
@@ -318,7 +380,9 @@ def add_const_to_oneof_variants(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
return _process_oneof(deepcopy(schema))
|
||||
|
||||
|
||||
def convert_oneof_to_anyof(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
def convert_oneof_to_anyof(
|
||||
schema: dict[str, Any], _seen: set[int] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Convert oneOf to anyOf for OpenAI compatibility.
|
||||
|
||||
OpenAI's Structured Outputs support anyOf better than oneOf.
|
||||
@@ -326,26 +390,37 @@ def convert_oneof_to_anyof(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
Args:
|
||||
schema: JSON schema dictionary.
|
||||
_seen: Internal set of visited ``id()``s, used to guard cyclic schemas.
|
||||
|
||||
Returns:
|
||||
Modified schema with anyOf instead of oneOf.
|
||||
"""
|
||||
if _seen is None:
|
||||
_seen = set()
|
||||
if isinstance(schema, dict):
|
||||
if id(schema) in _seen:
|
||||
return schema
|
||||
_seen.add(id(schema))
|
||||
if "oneOf" in schema:
|
||||
schema["anyOf"] = schema.pop("oneOf")
|
||||
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
convert_oneof_to_anyof(value)
|
||||
convert_oneof_to_anyof(value, _seen)
|
||||
elif isinstance(value, list):
|
||||
if id(value) in _seen:
|
||||
continue
|
||||
_seen.add(id(value))
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
convert_oneof_to_anyof(item)
|
||||
convert_oneof_to_anyof(item, _seen)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def ensure_all_properties_required(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
def ensure_all_properties_required(
|
||||
schema: dict[str, Any], _seen: set[int] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Ensure all properties are in the required array for OpenAI strict mode.
|
||||
|
||||
OpenAI's strict structured outputs require all properties to be listed
|
||||
@@ -354,11 +429,17 @@ def ensure_all_properties_required(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
Args:
|
||||
schema: JSON schema dictionary.
|
||||
_seen: Internal set of visited ``id()``s, used to guard cyclic schemas.
|
||||
|
||||
Returns:
|
||||
Modified schema with all properties marked as required.
|
||||
"""
|
||||
if _seen is None:
|
||||
_seen = set()
|
||||
if isinstance(schema, dict):
|
||||
if id(schema) in _seen:
|
||||
return schema
|
||||
_seen.add(id(schema))
|
||||
if schema.get("type") == "object" and "properties" in schema:
|
||||
properties = schema["properties"]
|
||||
if properties:
|
||||
@@ -366,16 +447,21 @@ def ensure_all_properties_required(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
ensure_all_properties_required(value)
|
||||
ensure_all_properties_required(value, _seen)
|
||||
elif isinstance(value, list):
|
||||
if id(value) in _seen:
|
||||
continue
|
||||
_seen.add(id(value))
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
ensure_all_properties_required(item)
|
||||
ensure_all_properties_required(item, _seen)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def strip_null_from_types(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
def strip_null_from_types(
|
||||
schema: dict[str, Any], _seen: set[int] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Remove null type from anyOf/type arrays.
|
||||
|
||||
Pydantic generates `T | None` for optional fields, which creates schemas with
|
||||
@@ -384,11 +470,17 @@ def strip_null_from_types(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
Args:
|
||||
schema: JSON schema dictionary.
|
||||
_seen: Internal set of visited ``id()``s, used to guard cyclic schemas.
|
||||
|
||||
Returns:
|
||||
Modified schema with null types removed.
|
||||
"""
|
||||
if _seen is None:
|
||||
_seen = set()
|
||||
if isinstance(schema, dict):
|
||||
if id(schema) in _seen:
|
||||
return schema
|
||||
_seen.add(id(schema))
|
||||
if "anyOf" in schema:
|
||||
any_of = schema["anyOf"]
|
||||
non_null = [opt for opt in any_of if opt.get("type") != "null"]
|
||||
@@ -408,15 +500,141 @@ def strip_null_from_types(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
strip_null_from_types(value)
|
||||
strip_null_from_types(value, _seen)
|
||||
elif isinstance(value, list):
|
||||
if id(value) in _seen:
|
||||
continue
|
||||
_seen.add(id(value))
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
strip_null_from_types(item)
|
||||
strip_null_from_types(item, _seen)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
_STRICT_METADATA_KEYS: Final[tuple[str, ...]] = (
|
||||
"title",
|
||||
"default",
|
||||
"examples",
|
||||
"example",
|
||||
"$comment",
|
||||
"readOnly",
|
||||
"writeOnly",
|
||||
"deprecated",
|
||||
)
|
||||
|
||||
_CLAUDE_STRICT_UNSUPPORTED: Final[tuple[str, ...]] = (
|
||||
"minimum",
|
||||
"maximum",
|
||||
"exclusiveMinimum",
|
||||
"exclusiveMaximum",
|
||||
"multipleOf",
|
||||
"minLength",
|
||||
"maxLength",
|
||||
"pattern",
|
||||
"minItems",
|
||||
"maxItems",
|
||||
"uniqueItems",
|
||||
"minContains",
|
||||
"maxContains",
|
||||
"minProperties",
|
||||
"maxProperties",
|
||||
"patternProperties",
|
||||
"propertyNames",
|
||||
"dependentRequired",
|
||||
"dependentSchemas",
|
||||
)
|
||||
|
||||
|
||||
def _strip_keys_recursive(
|
||||
d: Any, keys: tuple[str, ...], _seen: set[int] | None = None
|
||||
) -> Any:
|
||||
"""Recursively delete a fixed set of keys from a schema."""
|
||||
if _seen is None:
|
||||
_seen = set()
|
||||
if isinstance(d, dict):
|
||||
if id(d) in _seen:
|
||||
return d
|
||||
_seen.add(id(d))
|
||||
for key in keys:
|
||||
d.pop(key, None)
|
||||
for v in d.values():
|
||||
_strip_keys_recursive(v, keys, _seen)
|
||||
elif isinstance(d, list):
|
||||
if id(d) in _seen:
|
||||
return d
|
||||
_seen.add(id(d))
|
||||
for i in d:
|
||||
_strip_keys_recursive(i, keys, _seen)
|
||||
return d
|
||||
|
||||
|
||||
def lift_top_level_anyof(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Unwrap a top-level anyOf/oneOf/allOf wrapping a single object variant.
|
||||
|
||||
Anthropic's strict ``input_schema`` rejects top-level union keywords. When
|
||||
exactly one variant is an object schema, lift it so the root is a plain
|
||||
object; otherwise leave the schema alone.
|
||||
"""
|
||||
for key in ("anyOf", "oneOf", "allOf"):
|
||||
variants = schema.get(key)
|
||||
if not isinstance(variants, list):
|
||||
continue
|
||||
object_variants = [
|
||||
v for v in variants if isinstance(v, dict) and v.get("type") == "object"
|
||||
]
|
||||
if len(object_variants) == 1:
|
||||
lifted = deepcopy(object_variants[0])
|
||||
schema.pop(key)
|
||||
schema.update(lifted)
|
||||
break
|
||||
return schema
|
||||
|
||||
|
||||
def _common_strict_pipeline(params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Shared strict sanitization: inline refs, close objects, require all properties."""
|
||||
sanitized = resolve_refs(deepcopy(params))
|
||||
sanitized.pop("$defs", None)
|
||||
sanitized = convert_oneof_to_anyof(sanitized)
|
||||
sanitized = ensure_type_in_schemas(sanitized)
|
||||
sanitized = force_additional_properties_false(sanitized)
|
||||
sanitized = ensure_all_properties_required(sanitized)
|
||||
return cast(dict[str, Any], _strip_keys_recursive(sanitized, _STRICT_METADATA_KEYS))
|
||||
|
||||
|
||||
def sanitize_tool_params_for_openai_strict(
|
||||
params: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Sanitize a JSON schema for OpenAI strict function calling."""
|
||||
if not isinstance(params, dict):
|
||||
return params
|
||||
return cast(
|
||||
dict[str, Any], strip_unsupported_formats(_common_strict_pipeline(params))
|
||||
)
|
||||
|
||||
|
||||
def sanitize_tool_params_for_anthropic_strict(
|
||||
params: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Sanitize a JSON schema for Anthropic strict tool use."""
|
||||
if not isinstance(params, dict):
|
||||
return params
|
||||
sanitized = lift_top_level_anyof(_common_strict_pipeline(params))
|
||||
sanitized = _strip_keys_recursive(sanitized, _CLAUDE_STRICT_UNSUPPORTED)
|
||||
return cast(dict[str, Any], strip_unsupported_formats(sanitized))
|
||||
|
||||
|
||||
def sanitize_tool_params_for_bedrock_strict(
|
||||
params: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Sanitize a JSON schema for Bedrock Converse strict tool use.
|
||||
|
||||
Bedrock Converse uses the same grammar compiler as the underlying Claude
|
||||
model, so the constraints match Anthropic's.
|
||||
"""
|
||||
return sanitize_tool_params_for_anthropic_strict(params)
|
||||
|
||||
|
||||
def generate_model_description(
|
||||
model: type[BaseModel],
|
||||
*,
|
||||
@@ -545,6 +763,25 @@ def build_rich_field_description(prop_schema: dict[str, Any]) -> str:
|
||||
return ". ".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _inline_top_level_ref(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Resolve only the top-level ``$ref``, preserving ``$defs`` for lazy inner resolution.
|
||||
|
||||
Used as a fallback when ``jsonref.replace_refs`` fails on circular schemas.
|
||||
Inner ``$ref`` pointers are left intact so that :func:`_resolve_ref` can
|
||||
resolve them during model construction, with cycle detection via ``in_progress``.
|
||||
"""
|
||||
schema = deepcopy(schema)
|
||||
ref = schema.get("$ref")
|
||||
if isinstance(ref, str) and ref.startswith("#/$defs/"):
|
||||
def_name = ref[len("#/$defs/") :]
|
||||
defs = schema.get("$defs", {})
|
||||
if def_name in defs:
|
||||
resolved: dict[str, Any] = deepcopy(defs[def_name])
|
||||
resolved.setdefault("$defs", defs)
|
||||
return resolved
|
||||
return schema
|
||||
|
||||
|
||||
def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
json_schema: dict[str, Any],
|
||||
*,
|
||||
@@ -599,19 +836,80 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
>>> person.name
|
||||
'John'
|
||||
"""
|
||||
json_schema = dict(jsonref.replace_refs(json_schema, proxies=False))
|
||||
try:
|
||||
json_schema = dict(jsonref.replace_refs(json_schema, proxies=False))
|
||||
except (jsonref.JsonRefError, RecursionError):
|
||||
json_schema = _inline_top_level_ref(json_schema)
|
||||
|
||||
effective_root = root_schema or json_schema
|
||||
|
||||
json_schema = force_additional_properties_false(json_schema)
|
||||
effective_root = force_additional_properties_false(effective_root)
|
||||
|
||||
in_progress: dict[int, Any] = {}
|
||||
model = _build_model_from_schema(
|
||||
json_schema,
|
||||
effective_root,
|
||||
model_name=model_name,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
in_progress=in_progress,
|
||||
__config__=__config__,
|
||||
__base__=__base__,
|
||||
__module__=__module__,
|
||||
__validators__=__validators__,
|
||||
__cls_kwargs__=__cls_kwargs__,
|
||||
)
|
||||
|
||||
types_namespace: dict[str, Any] = {
|
||||
entry.__name__: entry
|
||||
for entry in in_progress.values()
|
||||
if isinstance(entry, type) and issubclass(entry, BaseModel)
|
||||
}
|
||||
for entry in in_progress.values():
|
||||
if (
|
||||
isinstance(entry, type)
|
||||
and issubclass(entry, BaseModel)
|
||||
and not getattr(entry, "__pydantic_complete__", True)
|
||||
):
|
||||
try:
|
||||
entry.model_rebuild(_types_namespace=types_namespace)
|
||||
except Exception as e:
|
||||
logger.debug("model_rebuild failed for %s: %s", entry.__name__, e)
|
||||
return model
|
||||
|
||||
|
||||
def _build_model_from_schema( # type: ignore[no-any-unimported]
|
||||
json_schema: dict[str, Any],
|
||||
effective_root: dict[str, Any],
|
||||
*,
|
||||
model_name: str | None,
|
||||
enrich_descriptions: bool,
|
||||
in_progress: dict[int, Any],
|
||||
__config__: ConfigDict | None = None,
|
||||
__base__: type[BaseModel] | None = None,
|
||||
__module__: str = __name__,
|
||||
__validators__: dict[str, AnyClassMethod] | None = None,
|
||||
__cls_kwargs__: dict[str, Any] | None = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Inner builder shared by the public entry point and recursive nested-object creation.
|
||||
|
||||
Preprocessing via ``jsonref.replace_refs`` and the sanitization walkers is
|
||||
run once by the public entry; this helper walks the already-normalized
|
||||
schema and emits Pydantic models. ``in_progress`` maps ``id(schema)`` to
|
||||
the model being built for that schema, so a cyclic ``$ref`` graph
|
||||
degrades to a ``ForwardRef`` back-edge instead of blowing the stack.
|
||||
"""
|
||||
original_id = id(json_schema)
|
||||
if "allOf" in json_schema:
|
||||
json_schema = _merge_all_of_schemas(json_schema["allOf"], effective_root)
|
||||
if "title" not in json_schema and "title" in (root_schema or {}):
|
||||
json_schema["title"] = (root_schema or {}).get("title")
|
||||
|
||||
effective_name = model_name or json_schema.get("title") or "DynamicModel"
|
||||
|
||||
schema_id = id(json_schema)
|
||||
in_progress[original_id] = effective_name
|
||||
if schema_id != original_id:
|
||||
in_progress[schema_id] = effective_name
|
||||
|
||||
field_definitions = {
|
||||
name: _json_schema_to_pydantic_field(
|
||||
name,
|
||||
@@ -619,13 +917,14 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
json_schema.get("required", []),
|
||||
effective_root,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
in_progress=in_progress,
|
||||
)
|
||||
for name, prop in (json_schema.get("properties", {}) or {}).items()
|
||||
}
|
||||
|
||||
effective_config = __config__ or ConfigDict(extra="forbid")
|
||||
|
||||
return create_model_base(
|
||||
model = create_model_base(
|
||||
effective_name,
|
||||
__config__=effective_config,
|
||||
__base__=__base__,
|
||||
@@ -634,6 +933,10 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
__cls_kwargs__=__cls_kwargs__,
|
||||
**field_definitions,
|
||||
)
|
||||
in_progress[original_id] = model
|
||||
if schema_id != original_id:
|
||||
in_progress[schema_id] = model
|
||||
return model
|
||||
|
||||
|
||||
def _json_schema_to_pydantic_field(
|
||||
@@ -643,6 +946,7 @@ def _json_schema_to_pydantic_field(
|
||||
root_schema: dict[str, Any],
|
||||
*,
|
||||
enrich_descriptions: bool = False,
|
||||
in_progress: dict[int, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Convert a JSON schema property to a Pydantic field definition.
|
||||
|
||||
@@ -661,6 +965,7 @@ def _json_schema_to_pydantic_field(
|
||||
root_schema,
|
||||
name_=name.title(),
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
in_progress=in_progress,
|
||||
)
|
||||
is_required = name in required
|
||||
|
||||
@@ -720,7 +1025,7 @@ def _json_schema_to_pydantic_field(
|
||||
field_params["pattern"] = json_schema["pattern"]
|
||||
|
||||
if not is_required:
|
||||
type_ = type_ | None
|
||||
type_ = Optional[type_] # noqa: UP045 - ForwardRef does not support `|`
|
||||
|
||||
if schema_extra:
|
||||
field_params["json_schema_extra"] = schema_extra
|
||||
@@ -793,6 +1098,7 @@ def _json_schema_to_pydantic_type(
|
||||
*,
|
||||
name_: str | None = None,
|
||||
enrich_descriptions: bool = False,
|
||||
in_progress: dict[int, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Convert a JSON schema to a Python/Pydantic type.
|
||||
|
||||
@@ -801,10 +1107,23 @@ def _json_schema_to_pydantic_type(
|
||||
root_schema: The root schema for resolving $ref.
|
||||
name_: Optional name for nested models.
|
||||
enrich_descriptions: Propagated to nested model creation.
|
||||
in_progress: Map of ``id(schema_dict)`` to the Pydantic model
|
||||
currently being built for that schema, or to a placeholder name
|
||||
as a plain ``str`` while the model is still being constructed.
|
||||
Populated by :func:`_build_model_from_schema`. Enables cycle
|
||||
detection so a self-referential ``$ref`` graph resolves to a
|
||||
:class:`ForwardRef` back-edge rather than recursing forever.
|
||||
|
||||
Returns:
|
||||
A Python type corresponding to the JSON schema.
|
||||
"""
|
||||
if in_progress is not None:
|
||||
cached = in_progress.get(id(json_schema))
|
||||
if isinstance(cached, str):
|
||||
return ForwardRef(cached)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
ref = json_schema.get("$ref")
|
||||
if ref:
|
||||
ref_schema = _resolve_ref(ref, root_schema)
|
||||
@@ -813,6 +1132,7 @@ def _json_schema_to_pydantic_type(
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
in_progress=in_progress,
|
||||
)
|
||||
|
||||
enum_values = json_schema.get("enum")
|
||||
@@ -832,6 +1152,7 @@ def _json_schema_to_pydantic_type(
|
||||
root_schema,
|
||||
name_=f"{name_ or 'Union'}Option{i}",
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
in_progress=in_progress,
|
||||
)
|
||||
for i, schema in enumerate(any_of_schemas)
|
||||
]
|
||||
@@ -845,6 +1166,15 @@ def _json_schema_to_pydantic_type(
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
in_progress=in_progress,
|
||||
)
|
||||
if in_progress is not None:
|
||||
return _build_model_from_schema(
|
||||
json_schema,
|
||||
root_schema,
|
||||
model_name=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
in_progress=in_progress,
|
||||
)
|
||||
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
|
||||
return _json_schema_to_pydantic_type(
|
||||
@@ -852,6 +1182,7 @@ def _json_schema_to_pydantic_type(
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
in_progress=in_progress,
|
||||
)
|
||||
|
||||
type_ = json_schema.get("type")
|
||||
@@ -872,12 +1203,21 @@ def _json_schema_to_pydantic_type(
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
in_progress=in_progress,
|
||||
)
|
||||
return list[item_type] # type: ignore[valid-type]
|
||||
return list
|
||||
if type_ == "object":
|
||||
properties = json_schema.get("properties")
|
||||
if properties:
|
||||
if in_progress is not None:
|
||||
return _build_model_from_schema(
|
||||
json_schema,
|
||||
root_schema,
|
||||
model_name=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
in_progress=in_progress,
|
||||
)
|
||||
json_schema_ = json_schema.copy()
|
||||
if json_schema_.get("title") is None:
|
||||
json_schema_["title"] = name_ or "DynamicModel"
|
||||
|
||||
@@ -15,6 +15,7 @@ from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from crewai.llm import LLM
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.planning_types import PlanStep
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
@@ -481,17 +482,17 @@ class AgentReasoning:
|
||||
"""Get the system prompt for planning.
|
||||
|
||||
Returns:
|
||||
The system prompt, either custom or from i18n.
|
||||
The system prompt, either custom or from I18N_DEFAULT.
|
||||
"""
|
||||
if self.config.system_prompt is not None:
|
||||
return self.config.system_prompt
|
||||
|
||||
# Try new "planning" section first, fall back to "reasoning" for compatibility
|
||||
try:
|
||||
return self.agent.i18n.retrieve("planning", "system_prompt")
|
||||
return I18N_DEFAULT.retrieve("planning", "system_prompt")
|
||||
except (KeyError, AttributeError):
|
||||
# Fallback to reasoning section for backward compatibility
|
||||
return self.agent.i18n.retrieve("reasoning", "initial_plan").format(
|
||||
return I18N_DEFAULT.retrieve("reasoning", "initial_plan").format(
|
||||
role=self.agent.role,
|
||||
goal=self.agent.goal,
|
||||
backstory=self._get_agent_backstory(),
|
||||
@@ -527,7 +528,7 @@ class AgentReasoning:
|
||||
|
||||
# Try new "planning" section first
|
||||
try:
|
||||
return self.agent.i18n.retrieve("planning", "create_plan_prompt").format(
|
||||
return I18N_DEFAULT.retrieve("planning", "create_plan_prompt").format(
|
||||
description=self.description,
|
||||
expected_output=self.expected_output,
|
||||
tools=available_tools,
|
||||
@@ -535,7 +536,7 @@ class AgentReasoning:
|
||||
)
|
||||
except (KeyError, AttributeError):
|
||||
# Fallback to reasoning section for backward compatibility
|
||||
return self.agent.i18n.retrieve("reasoning", "create_plan_prompt").format(
|
||||
return I18N_DEFAULT.retrieve("reasoning", "create_plan_prompt").format(
|
||||
role=self.agent.role,
|
||||
goal=self.agent.goal,
|
||||
backstory=self._get_agent_backstory(),
|
||||
@@ -584,12 +585,12 @@ class AgentReasoning:
|
||||
|
||||
# Try new "planning" section first
|
||||
try:
|
||||
return self.agent.i18n.retrieve("planning", "refine_plan_prompt").format(
|
||||
return I18N_DEFAULT.retrieve("planning", "refine_plan_prompt").format(
|
||||
current_plan=current_plan,
|
||||
)
|
||||
except (KeyError, AttributeError):
|
||||
# Fallback to reasoning section for backward compatibility
|
||||
return self.agent.i18n.retrieve("reasoning", "refine_plan_prompt").format(
|
||||
return I18N_DEFAULT.retrieve("reasoning", "refine_plan_prompt").format(
|
||||
role=self.agent.role,
|
||||
goal=self.agent.goal,
|
||||
backstory=self._get_agent_backstory(),
|
||||
@@ -642,7 +643,7 @@ def _call_llm_with_reasoning_prompt(
|
||||
Returns:
|
||||
The LLM response.
|
||||
"""
|
||||
system_prompt = reasoning_agent.i18n.retrieve("reasoning", plan_type).format(
|
||||
system_prompt = I18N_DEFAULT.retrieve("reasoning", plan_type).format(
|
||||
role=reasoning_agent.role,
|
||||
goal=reasoning_agent.goal,
|
||||
backstory=backstory,
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator, Callable, Iterator
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import Any, NamedTuple
|
||||
import uuid
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -22,6 +24,13 @@ from crewai.types.streaming import (
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_current_stream_ids: contextvars.ContextVar[tuple[str, ...]] = contextvars.ContextVar(
|
||||
"_current_stream_ids", default=()
|
||||
)
|
||||
|
||||
|
||||
class TaskInfo(TypedDict):
|
||||
"""Task context information for streaming."""
|
||||
|
||||
@@ -41,6 +50,7 @@ class StreamingState(NamedTuple):
|
||||
async_queue: asyncio.Queue[StreamChunk | None | Exception] | None
|
||||
loop: asyncio.AbstractEventLoop | None
|
||||
handler: Callable[[Any, BaseEvent], None]
|
||||
stream_id: str | None = None
|
||||
|
||||
|
||||
def _extract_tool_call_info(
|
||||
@@ -102,6 +112,7 @@ def _create_stream_handler(
|
||||
sync_queue: queue.Queue[StreamChunk | None | Exception],
|
||||
async_queue: asyncio.Queue[StreamChunk | None | Exception] | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
stream_id: str | None = None,
|
||||
) -> Callable[[Any, BaseEvent], None]:
|
||||
"""Create a stream handler function.
|
||||
|
||||
@@ -110,21 +121,19 @@ def _create_stream_handler(
|
||||
sync_queue: Synchronous queue for chunks.
|
||||
async_queue: Optional async queue for chunks.
|
||||
loop: Optional event loop for async operations.
|
||||
stream_id: Stream scope ID for concurrent isolation.
|
||||
|
||||
Returns:
|
||||
Handler function that can be registered with the event bus.
|
||||
"""
|
||||
|
||||
def stream_handler(_: Any, event: BaseEvent) -> None:
|
||||
"""Handle LLM stream chunk events and enqueue them.
|
||||
|
||||
Args:
|
||||
_: Event source (unused).
|
||||
event: The event to process.
|
||||
"""
|
||||
if not isinstance(event, LLMStreamChunkEvent):
|
||||
return
|
||||
|
||||
if stream_id is not None and stream_id not in _current_stream_ids.get():
|
||||
return
|
||||
|
||||
chunk = _create_stream_chunk(event, current_task_info)
|
||||
|
||||
if async_queue is not None and loop is not None:
|
||||
@@ -159,10 +168,23 @@ def _finalize_streaming(
|
||||
streaming_output: The streaming output to set the result on.
|
||||
"""
|
||||
_unregister_handler(state.handler)
|
||||
streaming_output._on_cleanup = None
|
||||
if state.result_holder:
|
||||
streaming_output._set_result(state.result_holder[0])
|
||||
|
||||
|
||||
def register_cleanup(
|
||||
streaming_output: CrewStreamingOutput | FlowStreamingOutput,
|
||||
state: StreamingState,
|
||||
) -> None:
|
||||
"""Register a cleanup callback on the streaming output.
|
||||
|
||||
Ensures the event handler is unregistered even if aclose()/close()
|
||||
is called before iteration starts.
|
||||
"""
|
||||
streaming_output._on_cleanup = lambda: _unregister_handler(state.handler)
|
||||
|
||||
|
||||
def create_streaming_state(
|
||||
current_task_info: TaskInfo,
|
||||
result_holder: list[Any],
|
||||
@@ -186,7 +208,11 @@ def create_streaming_state(
|
||||
async_queue = asyncio.Queue()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
handler = _create_stream_handler(current_task_info, sync_queue, async_queue, loop)
|
||||
stream_id = str(uuid.uuid4())
|
||||
|
||||
handler = _create_stream_handler(
|
||||
current_task_info, sync_queue, async_queue, loop, stream_id=stream_id
|
||||
)
|
||||
crewai_event_bus.register_handler(LLMStreamChunkEvent, handler)
|
||||
|
||||
return StreamingState(
|
||||
@@ -196,6 +222,7 @@ def create_streaming_state(
|
||||
async_queue=async_queue,
|
||||
loop=loop,
|
||||
handler=handler,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -243,7 +270,12 @@ def create_chunk_generator(
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
"""
|
||||
ctx = contextvars.copy_context()
|
||||
if state.stream_id is not None:
|
||||
token = _current_stream_ids.set((*_current_stream_ids.get(), state.stream_id))
|
||||
ctx = contextvars.copy_context()
|
||||
_current_stream_ids.reset(token)
|
||||
else:
|
||||
ctx = contextvars.copy_context()
|
||||
thread = threading.Thread(target=ctx.run, args=(run_func,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
@@ -283,7 +315,12 @@ async def create_async_chunk_generator(
|
||||
"Async queue not initialized. Use create_streaming_state(use_async=True)."
|
||||
)
|
||||
|
||||
task = asyncio.create_task(run_coro())
|
||||
if state.stream_id is not None:
|
||||
token = _current_stream_ids.set((*_current_stream_ids.get(), state.stream_id))
|
||||
task = asyncio.create_task(run_coro())
|
||||
_current_stream_ids.reset(token)
|
||||
else:
|
||||
task = asyncio.create_task(run_coro())
|
||||
|
||||
try:
|
||||
while True:
|
||||
@@ -294,7 +331,14 @@ async def create_async_chunk_generator(
|
||||
raise item
|
||||
yield item
|
||||
finally:
|
||||
await task
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.debug("Background streaming task failed", exc_info=True)
|
||||
if output_holder:
|
||||
_finalize_streaming(state, output_holder[0])
|
||||
else:
|
||||
|
||||
@@ -13,7 +13,7 @@ from crewai.security.fingerprint import Fingerprint
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.tools.tool_usage import ToolUsage, ToolUsageError
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
@@ -30,7 +30,6 @@ if TYPE_CHECKING:
|
||||
async def aexecute_tool_and_check_finality(
|
||||
agent_action: AgentAction,
|
||||
tools: list[CrewStructuredTool],
|
||||
i18n: I18N,
|
||||
agent_key: str | None = None,
|
||||
agent_role: str | None = None,
|
||||
tools_handler: ToolsHandler | None = None,
|
||||
@@ -49,7 +48,6 @@ async def aexecute_tool_and_check_finality(
|
||||
Args:
|
||||
agent_action: The action containing the tool to execute.
|
||||
tools: List of available tools.
|
||||
i18n: Internationalization settings.
|
||||
agent_key: Optional key for event emission.
|
||||
agent_role: Optional role for event emission.
|
||||
tools_handler: Optional tools handler for tool execution.
|
||||
@@ -96,7 +94,7 @@ async def aexecute_tool_and_check_finality(
|
||||
if tool:
|
||||
tool_input = tool_calling.arguments if tool_calling.arguments else {}
|
||||
hook_context = ToolCallHookContext(
|
||||
tool_name=tool_calling.tool_name,
|
||||
tool_name=sanitized_tool_name,
|
||||
tool_input=tool_input,
|
||||
tool=tool,
|
||||
agent=agent,
|
||||
@@ -120,7 +118,7 @@ async def aexecute_tool_and_check_finality(
|
||||
tool_result = await tool_usage.ause(tool_calling, agent_action.text)
|
||||
|
||||
after_hook_context = ToolCallHookContext(
|
||||
tool_name=tool_calling.tool_name,
|
||||
tool_name=sanitized_tool_name,
|
||||
tool_input=tool_input,
|
||||
tool=tool,
|
||||
agent=agent,
|
||||
@@ -142,7 +140,7 @@ async def aexecute_tool_and_check_finality(
|
||||
|
||||
return ToolResult(modified_result, tool.result_as_answer)
|
||||
|
||||
tool_result = i18n.errors("wrong_tool_name").format(
|
||||
tool_result = I18N_DEFAULT.errors("wrong_tool_name").format(
|
||||
tool=sanitized_tool_name,
|
||||
tools=", ".join(tool_name_to_tool_map.keys()),
|
||||
)
|
||||
@@ -152,7 +150,6 @@ async def aexecute_tool_and_check_finality(
|
||||
def execute_tool_and_check_finality(
|
||||
agent_action: AgentAction,
|
||||
tools: list[CrewStructuredTool],
|
||||
i18n: I18N,
|
||||
agent_key: str | None = None,
|
||||
agent_role: str | None = None,
|
||||
tools_handler: ToolsHandler | None = None,
|
||||
@@ -170,7 +167,6 @@ def execute_tool_and_check_finality(
|
||||
Args:
|
||||
agent_action: The action containing the tool to execute
|
||||
tools: List of available tools
|
||||
i18n: Internationalization settings
|
||||
agent_key: Optional key for event emission
|
||||
agent_role: Optional role for event emission
|
||||
tools_handler: Optional tools handler for tool execution
|
||||
@@ -216,7 +212,7 @@ def execute_tool_and_check_finality(
|
||||
if tool:
|
||||
tool_input = tool_calling.arguments if tool_calling.arguments else {}
|
||||
hook_context = ToolCallHookContext(
|
||||
tool_name=tool_calling.tool_name,
|
||||
tool_name=sanitized_tool_name,
|
||||
tool_input=tool_input,
|
||||
tool=tool,
|
||||
agent=agent,
|
||||
@@ -240,7 +236,7 @@ def execute_tool_and_check_finality(
|
||||
tool_result = tool_usage.use(tool_calling, agent_action.text)
|
||||
|
||||
after_hook_context = ToolCallHookContext(
|
||||
tool_name=tool_calling.tool_name,
|
||||
tool_name=sanitized_tool_name,
|
||||
tool_input=tool_input,
|
||||
tool=tool,
|
||||
agent=agent,
|
||||
@@ -263,7 +259,7 @@ def execute_tool_and_check_finality(
|
||||
|
||||
return ToolResult(modified_result, tool.result_as_answer)
|
||||
|
||||
tool_result = i18n.errors("wrong_tool_name").format(
|
||||
tool_result = I18N_DEFAULT.errors("wrong_tool_name").format(
|
||||
tool=sanitized_tool_name,
|
||||
tools=", ".join(tool_name_to_tool_map.keys()),
|
||||
)
|
||||
|
||||
12
lib/crewai/src/crewai/utilities/version.py
Normal file
12
lib/crewai/src/crewai/utilities/version.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Version utilities for crewAI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cache
|
||||
import importlib.metadata
|
||||
|
||||
|
||||
@cache
|
||||
def get_crewai_version() -> str:
|
||||
"""Get the installed crewAI version string."""
|
||||
return importlib.metadata.version("crewai")
|
||||
@@ -1208,12 +1208,10 @@ def test_llm_call_with_error():
|
||||
def test_handle_context_length_exceeds_limit():
|
||||
# Import necessary modules
|
||||
from crewai.utilities.agent_utils import handle_context_length
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
# Create mocks for dependencies
|
||||
printer = Printer()
|
||||
i18n = I18N()
|
||||
|
||||
# Create an agent just for its LLM
|
||||
agent = Agent(
|
||||
@@ -1249,7 +1247,6 @@ def test_handle_context_length_exceeds_limit():
|
||||
messages=messages,
|
||||
llm=llm,
|
||||
callbacks=callbacks,
|
||||
i18n=i18n,
|
||||
)
|
||||
|
||||
# Verify our patch was called and raised the correct error
|
||||
@@ -1994,7 +1991,7 @@ def test_litellm_anthropic_error_handling():
|
||||
@pytest.mark.vcr()
|
||||
def test_get_knowledge_search_query():
|
||||
"""Test that _get_knowledge_search_query calls the LLM with the correct prompts."""
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
content = "The capital of France is Paris."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
@@ -2013,7 +2010,6 @@ def test_get_knowledge_search_query():
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
i18n = I18N()
|
||||
task_prompt = task.prompt()
|
||||
|
||||
with (
|
||||
@@ -2050,13 +2046,13 @@ def test_get_knowledge_search_query():
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": i18n.slice(
|
||||
"content": I18N_DEFAULT.slice(
|
||||
"knowledge_search_query_system_prompt"
|
||||
).format(task_prompt=task.description),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": i18n.slice("knowledge_search_query").format(
|
||||
"content": I18N_DEFAULT.slice("knowledge_search_query").format(
|
||||
task_prompt=task_prompt
|
||||
),
|
||||
},
|
||||
|
||||
@@ -48,8 +48,6 @@ def _build_executor(**kwargs: Any) -> AgentExecutor:
|
||||
executor._last_context_error = None
|
||||
executor._step_executor = None
|
||||
executor._planner_observer = None
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
executor._i18n = kwargs.get("i18n") or get_i18n()
|
||||
return executor
|
||||
from crewai.agents.planner_observer import PlannerObserver
|
||||
from crewai.experimental.agent_executor import (
|
||||
|
||||
@@ -1051,7 +1051,7 @@ def test_lite_agent_verbose_false_suppresses_printer_output():
|
||||
successful_requests=1,
|
||||
)
|
||||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
with pytest.warns(FutureWarning):
|
||||
agent = LiteAgent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
|
||||
@@ -55,7 +55,7 @@ interactions:
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
- 2.31.0
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
@@ -63,50 +63,51 @@ interactions:
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
- 3.13.12
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-DIqxWpJbbFJoV8WlXhb9UYFbCmdPk\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1773385850,\n \"model\": \"gpt-4o-2024-08-06\",\n
|
||||
string: "{\n \"id\": \"chatcmpl-DTApYQx2LepfeRL1XcDKPgrhMFnQr\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1775845516,\n \"model\": \"gpt-4o-2024-08-06\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
|
||||
\ \"id\": \"call_G2i9RJGNXKVfnd8ZTaBG8Fwi\",\n \"type\":
|
||||
\"function\",\n \"function\": {\n \"name\": \"ask_question_to_coworker\",\n
|
||||
\ \"arguments\": \"{\\\"question\\\": \\\"What are some trending
|
||||
topics or ideas in various fields that could be explored for an article?\\\",
|
||||
\\\"context\\\": \\\"We need to generate a list of 5 interesting ideas to
|
||||
explore for an article. These ideas should be engaging and relevant to current
|
||||
trends or captivating subjects.\\\", \\\"coworker\\\": \\\"Researcher\\\"}\"\n
|
||||
\ }\n },\n {\n \"id\": \"call_j4KH2SGZvNeioql0HcRQ9NTp\",\n
|
||||
\ \"id\": \"call_BCh6lXsBTdixRuRh6OTBPoIJ\",\n \"type\":
|
||||
\"function\",\n \"function\": {\n \"name\": \"delegate_work_to_coworker\",\n
|
||||
\ \"arguments\": \"{\\\"task\\\": \\\"Come up with a list of 5
|
||||
interesting ideas to explore for an article.\\\", \\\"context\\\": \\\"We
|
||||
need five intriguing ideas worth exploring for an article. Each idea should
|
||||
have potential for in-depth exploration and appeal to a broad audience, possibly
|
||||
touching on current trends, historical insights, future possibilities, or
|
||||
human interest stories.\\\", \\\"coworker\\\": \\\"Researcher\\\"}\"\n }\n
|
||||
\ },\n {\n \"id\": \"call_rAQFeCrS4ogsqvIWRGAYFHGI\",\n
|
||||
\ \"type\": \"function\",\n \"function\": {\n \"name\":
|
||||
\"ask_question_to_coworker\",\n \"arguments\": \"{\\\"question\\\":
|
||||
\\\"What unique angles or perspectives could we explore to make articles more
|
||||
compelling and engaging?\\\", \\\"context\\\": \\\"Our task involves coming
|
||||
up with 5 ideas for articles, each with an exciting paragraph highlight that
|
||||
illustrates the promise and intrigue of the topic. We want them to be more
|
||||
than generic concepts, shining for readers with fresh insights or engaging
|
||||
twists.\\\", \\\"coworker\\\": \\\"Senior Writer\\\"}\"\n }\n }\n
|
||||
\ ],\n \"refusal\": null,\n \"annotations\": []\n },\n
|
||||
\ \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n }\n
|
||||
\ ],\n \"usage\": {\n \"prompt_tokens\": 476,\n \"completion_tokens\":
|
||||
183,\n \"total_tokens\": 659,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
\"delegate_work_to_coworker\",\n \"arguments\": \"{\\\"task\\\":
|
||||
\\\"Write one amazing paragraph highlight for each of 5 ideas that showcases
|
||||
how good an article about this topic could be.\\\", \\\"context\\\": \\\"Upon
|
||||
receiving five intriguing ideas from the Researcher, create a compelling paragraph
|
||||
for each idea that highlights its potential as a fascinating article. These
|
||||
paragraphs must capture the essence of the topic and explain why it would
|
||||
captivate readers, incorporating possible themes and insights.\\\", \\\"coworker\\\":
|
||||
\\\"Senior Writer\\\"}\"\n }\n }\n ],\n \"refusal\":
|
||||
null,\n \"annotations\": []\n },\n \"logprobs\": null,\n
|
||||
\ \"finish_reason\": \"tool_calls\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
|
||||
476,\n \"completion_tokens\": 201,\n \"total_tokens\": 677,\n \"prompt_tokens_details\":
|
||||
{\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_b7c8e3f100\"\n}\n"
|
||||
\"default\",\n \"system_fingerprint\": \"fp_2ca5b70601\"\n}\n"
|
||||
headers:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-Ray:
|
||||
- 9db9389a3f9e424c-EWR
|
||||
- 9ea3cb06ba66b301-TPE
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Fri, 13 Mar 2026 07:10:53 GMT
|
||||
- Fri, 10 Apr 2026 18:25:18 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
@@ -122,7 +123,7 @@ interactions:
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '2402'
|
||||
- '1981'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
@@ -154,13 +155,14 @@ interactions:
|
||||
You work as a freelancer and is now working on doing research and analysis for
|
||||
a new customer.\nYour personal goal is: Make the best research and analysis
|
||||
on content about AI and AI agents"},{"role":"user","content":"\nCurrent Task:
|
||||
What are some trending topics or ideas in various fields that could be explored
|
||||
for an article?\n\nThis is the expected criteria for your final answer: Your
|
||||
best answer to your coworker asking you this, accounting for the context shared.\nyou
|
||||
MUST return the actual complete content as the final answer, not a summary.\n\nThis
|
||||
is the context you''re working with:\nWe need to generate a list of 5 interesting
|
||||
ideas to explore for an article. These ideas should be engaging and relevant
|
||||
to current trends or captivating subjects.\n\nProvide your complete response:"}],"model":"gpt-4.1-mini"}'
|
||||
Come up with a list of 5 interesting ideas to explore for an article.\n\nThis
|
||||
is the expected criteria for your final answer: Your best answer to your coworker
|
||||
asking you this, accounting for the context shared.\nyou MUST return the actual
|
||||
complete content as the final answer, not a summary.\n\nThis is the context
|
||||
you''re working with:\nWe need five intriguing ideas worth exploring for an
|
||||
article. Each idea should have potential for in-depth exploration and appeal
|
||||
to a broad audience, possibly touching on current trends, historical insights,
|
||||
future possibilities, or human interest stories.\n\nProvide your complete response:"}],"model":"gpt-4.1-mini"}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
@@ -173,7 +175,7 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '978'
|
||||
- '1046'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
@@ -187,7 +189,7 @@ interactions:
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
- 2.31.0
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
@@ -195,63 +197,69 @@ interactions:
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
- 3.13.12
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-DIqxak88AexErt9PGFGHnWPIJLwNV\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1773385854,\n \"model\": \"gpt-4.1-mini-2025-04-14\",\n
|
||||
string: "{\n \"id\": \"chatcmpl-DTApalbfnYkqIc8slLS3DKwo9KXbc\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1775845518,\n \"model\": \"gpt-4.1-mini-2025-04-14\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"Here are five trending and engaging
|
||||
topics across various fields that could be explored for an article:\\n\\n1.
|
||||
**The Rise of Autonomous AI Agents and Their Impact on the Future of Work**
|
||||
\ \\nExplore how autonomous AI agents\u2014systems capable of performing complex
|
||||
tasks independently\u2014are transforming industries such as customer service,
|
||||
software development, and logistics. Discuss implications for job automation,
|
||||
human-AI collaboration, and ethical considerations surrounding decision-making
|
||||
autonomy.\\n\\n2. **Generative AI Beyond Text: Innovations in Audio, Video,
|
||||
and 3D Content Creation** \\nDelve into advancements in generative AI models
|
||||
that create not only text but also realistic audio, video content, virtual
|
||||
environments, and 3D models. Highlight applications in gaming, entertainment,
|
||||
education, and digital marketing, as well as challenges like misinformation
|
||||
and deepfake detection.\\n\\n3. **AI-Driven Climate Modeling: Enhancing Predictive
|
||||
Accuracy to Combat Climate Change** \\nExamine how AI and machine learning
|
||||
are improving climate models by analyzing vast datasets, uncovering patterns,
|
||||
and simulating environmental scenarios. Discuss how these advances are aiding
|
||||
policymakers in making informed decisions to address climate risks and sustainability
|
||||
goals.\\n\\n4. **The Ethical Frontiers of AI in Healthcare: Balancing Innovation
|
||||
with Patient Privacy** \\nInvestigate ethical challenges posed by AI applications
|
||||
in healthcare, including diagnosis, personalized treatment, and patient data
|
||||
management. Focus on balancing rapid technological innovation with privacy,
|
||||
bias mitigation, and regulatory frameworks to ensure equitable access and
|
||||
trust.\\n\\n5. **Quantum Computing Meets AI: Exploring the Next Leap in Computational
|
||||
Power** \\nCover the intersection of quantum computing and artificial intelligence,
|
||||
exploring how quantum algorithms could accelerate AI training processes and
|
||||
solve problems beyond the reach of classical computers. Outline current research,
|
||||
potential breakthroughs, and the timeline for real-world applications.\\n\\nEach
|
||||
of these topics is timely, relevant, and has the potential to engage readers
|
||||
interested in cutting-edge technology, societal impact, and future trends.
|
||||
Let me know if you want me to help develop an outline or deeper research into
|
||||
any of these areas!\",\n \"refusal\": null,\n \"annotations\":
|
||||
[]\n },\n \"logprobs\": null,\n \"finish_reason\": \"stop\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 178,\n \"completion_tokens\":
|
||||
402,\n \"total_tokens\": 580,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
\"assistant\",\n \"content\": \"Certainly! Here are five intriguing
|
||||
article ideas that offer rich potential for deep exploration and broad audience
|
||||
appeal, especially aligned with current trends and human interest in AI and
|
||||
technology:\\n\\n1. **The Evolution of AI Agents: From Rule-Based Bots to
|
||||
Autonomous Decision Makers** \\n Explore the historical development of
|
||||
AI agents, tracing the journey from simple scripted chatbots to advanced autonomous
|
||||
systems capable of complex decision-making and learning. Dive into key technological
|
||||
milestones, breakthroughs in machine learning, and current state-of-the-art
|
||||
AI agents. Discuss implications for industries such as customer service, healthcare,
|
||||
and autonomous vehicles, highlighting both opportunities and ethical concerns.\\n\\n2.
|
||||
**AI in Daily Life: How Intelligent Agents Are Reshaping Human Routines**
|
||||
\ \\n Investigate the integration of AI agents in everyday life\u2014from
|
||||
virtual assistants like Siri and Alexa to personalized recommendation systems
|
||||
and smart home devices. Analyze how these AI tools influence productivity,
|
||||
privacy, and social behavior. Include human interest elements through stories
|
||||
of individuals or communities who have embraced or resisted these technologies.\\n\\n3.
|
||||
**The Future of Work: AI Agents as Collaborative Colleagues** \\n Examine
|
||||
how AI agents are transforming workplaces by acting as collaborators rather
|
||||
than just tools. Cover applications in creative fields, data analysis, and
|
||||
decision support, while addressing potential challenges such as job displacement,
|
||||
new skill requirements, and the evolving definition of teamwork. Use expert
|
||||
opinions and case studies to paint a nuanced future outlook.\\n\\n4. **Ethics
|
||||
and Accountability in AI Agent Development** \\n Delve into the ethical
|
||||
dilemmas posed by increasingly autonomous AI agents\u2014topics like bias
|
||||
in algorithms, data privacy, and accountability for AI-driven decisions. Explore
|
||||
measures being taken globally to regulate AI, frameworks for responsible AI
|
||||
development, and the role of public awareness. Include historical context
|
||||
about technology ethics to provide depth.\\n\\n5. **Human-AI Symbiosis: Stories
|
||||
of Innovative Partnerships Shaping Our World** \\n Tell compelling human
|
||||
interest stories about individuals or organizations pioneering collaborative
|
||||
projects with AI agents that lead to breakthroughs in science, art, or social
|
||||
good. Highlight how these partnerships transcend traditional human-machine
|
||||
interaction and open new creative and problem-solving possibilities, inspiring
|
||||
readers about the potential of human-AI synergy.\\n\\nThese ideas are designed
|
||||
to be both engaging and informative, offering multiple angles\u2014technical,
|
||||
historical, ethical, and personal\u2014to keep readers captivated while providing
|
||||
substantial content for in-depth analysis.\",\n \"refusal\": null,\n
|
||||
\ \"annotations\": []\n },\n \"logprobs\": null,\n \"finish_reason\":
|
||||
\"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 189,\n \"completion_tokens\":
|
||||
472,\n \"total_tokens\": 661,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_e76a310957\"\n}\n"
|
||||
\"default\",\n \"system_fingerprint\": \"fp_fbf43a1ff3\"\n}\n"
|
||||
headers:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-Ray:
|
||||
- 9db938b0493c4b9f-EWR
|
||||
- 9ea3cb1b5c943323-TPE
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Fri, 13 Mar 2026 07:10:59 GMT
|
||||
- Fri, 10 Apr 2026 18:25:25 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
@@ -267,7 +275,7 @@ interactions:
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '5699'
|
||||
- '6990'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
@@ -298,15 +306,16 @@ interactions:
|
||||
a senior writer, specialized in technology, software engineering, AI and startups.
|
||||
You work as a freelancer and are now working on writing content for a new customer.\nYour
|
||||
personal goal is: Write the best content about AI and AI agents."},{"role":"user","content":"\nCurrent
|
||||
Task: What unique angles or perspectives could we explore to make articles more
|
||||
compelling and engaging?\n\nThis is the expected criteria for your final answer:
|
||||
Your best answer to your coworker asking you this, accounting for the context
|
||||
shared.\nyou MUST return the actual complete content as the final answer, not
|
||||
a summary.\n\nThis is the context you''re working with:\nOur task involves coming
|
||||
up with 5 ideas for articles, each with an exciting paragraph highlight that
|
||||
illustrates the promise and intrigue of the topic. We want them to be more than
|
||||
generic concepts, shining for readers with fresh insights or engaging twists.\n\nProvide
|
||||
your complete response:"}],"model":"gpt-4.1-mini"}'
|
||||
Task: Write one amazing paragraph highlight for each of 5 ideas that showcases
|
||||
how good an article about this topic could be.\n\nThis is the expected criteria
|
||||
for your final answer: Your best answer to your coworker asking you this, accounting
|
||||
for the context shared.\nyou MUST return the actual complete content as the
|
||||
final answer, not a summary.\n\nThis is the context you''re working with:\nUpon
|
||||
receiving five intriguing ideas from the Researcher, create a compelling paragraph
|
||||
for each idea that highlights its potential as a fascinating article. These
|
||||
paragraphs must capture the essence of the topic and explain why it would captivate
|
||||
readers, incorporating possible themes and insights.\n\nProvide your complete
|
||||
response:"}],"model":"gpt-4.1-mini"}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
@@ -319,7 +328,7 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '1041'
|
||||
- '1103'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
@@ -333,7 +342,7 @@ interactions:
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
- 2.31.0
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
@@ -341,78 +350,83 @@ interactions:
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
- 3.13.12
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-DIqxZCl1kFIE7WXznIKow9QFNZ2QT\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1773385853,\n \"model\": \"gpt-4.1-mini-2025-04-14\",\n
|
||||
string: "{\n \"id\": \"chatcmpl-DTApbrh9Z4yFAKPHIR48ubdB1R5xK\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1775845519,\n \"model\": \"gpt-4.1-mini-2025-04-14\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"Absolutely! To create compelling and
|
||||
engaging AI articles that stand out, we need to go beyond surface-level discussions
|
||||
and deliver fresh perspectives that challenge assumptions and spark curiosity.
|
||||
Here are five unique angles with their highlight paragraphs that could really
|
||||
captivate our readers:\\n\\n1. **The Hidden Psychology of AI Agents: How They
|
||||
Learn Human Biases and What That Means for Our Future** \\n*Highlight:* AI
|
||||
agents don\u2019t just process data\u2014they absorb the subtle nuances and
|
||||
biases embedded in human language, behavior, and culture. This article dives
|
||||
deep into the psychological parallels between AI learning mechanisms and human
|
||||
cognitive biases, revealing surprising ways AI can both mirror and amplify
|
||||
our prejudices. Understanding these dynamics is crucial for building trustworthy
|
||||
AI systems and reshaping the future relationship between humans and machines.\\n\\n2.
|
||||
**From Assistants to Autonomous Creators: The Rise of AI Agents as Artists,
|
||||
Writers, and Innovators** \\n*Highlight:* What do we lose and gain when AI
|
||||
agents start producing original art, literature, and innovations? This piece
|
||||
explores groundbreaking examples where AI isn\u2019t just a tool but a creative
|
||||
partner that challenges our definition of authorship and genius. We\u2019ll
|
||||
examine ethical dilemmas, collaborative workflows, and the exciting frontier
|
||||
where human intuition meets algorithmic originality.\\n\\n3. **AI Agents in
|
||||
the Wild: How Decentralized Autonomous Organizations Could Redefine Economy
|
||||
and Governance** \\n*Highlight:* Imagine AI agents operating autonomously
|
||||
in decentralized networks, making real-time decisions that affect finances,
|
||||
resource management, and governance without human intervention. This article
|
||||
uncovers how DAOs powered by AI agents might spontaneously evolve new forms
|
||||
of organization\u2014transparent, efficient, and resistant to traditional
|
||||
corruption. We\u2019ll investigate early case studies and speculate on how
|
||||
this might disrupt centuries-old societal structures.\\n\\n4. **Beyond Chatbots:
|
||||
The Next Generation of AI Agents as Empathetic Digital Companions** \\n*Highlight:*
|
||||
Moving past scripted conversations, emerging AI agents simulate empathy and
|
||||
emotional intelligence in ways that can transform mental health care, education,
|
||||
and companionship. This article provides an insider look at the complex algorithms
|
||||
and biofeedback mechanisms enabling AI to recognize, respond to, and foster
|
||||
human emotions\u2014potentially filling gaps in underserved populations while
|
||||
raising profound questions about authenticity and connection.\\n\\n5. **The
|
||||
Environmental Toll of AI Agents: Unmasking the Ecological Cost of Intelligent
|
||||
Automation** \\n*Highlight:* While AI promises efficiency and innovation,
|
||||
the environmental footprint of training and deploying millions of AI agents
|
||||
is rarely discussed. This eye-opening article quantifies the energy demands
|
||||
of current models, challenges the narrative of AI as an unequivocal green
|
||||
solution, and explores emerging approaches pathing toward sustainable intelligent
|
||||
automation\u2014an urgent conversation for an increasingly eco-conscious tech
|
||||
landscape.\\n\\nEach of these angles opens a door to rich storytelling that
|
||||
blends technical depth, ethical inquiry, and visionary implications\u2014perfect
|
||||
for readers hungry for insight that\u2019s both sophisticated and accessible.
|
||||
Let me know which ones resonate most, or if you want me to refine any into
|
||||
full article outlines!\",\n \"refusal\": null,\n \"annotations\":
|
||||
[]\n },\n \"logprobs\": null,\n \"finish_reason\": \"stop\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 188,\n \"completion_tokens\":
|
||||
595,\n \"total_tokens\": 783,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
\"assistant\",\n \"content\": \"1. **The Rise of Autonomous AI Agents:
|
||||
Revolutionizing Everyday Tasks** \\nImagine a world where AI agents autonomously
|
||||
manage your daily schedule, optimize your work routines, and even handle complex
|
||||
decision-making with minimal human intervention. An article exploring the
|
||||
rise of autonomous AI agents would captivate readers by diving into how advancements
|
||||
in machine learning and natural language processing have matured these agents
|
||||
from simple chatbots to intelligent collaborators. Themes could include practical
|
||||
applications in industries like healthcare, finance, and personal productivity,
|
||||
the challenges of trust and transparency, and a glimpse into the ethical questions
|
||||
surrounding AI autonomy. This topic not only showcases cutting-edge technology
|
||||
but also invites readers to envision the near future of human-AI synergy.\\n\\n2.
|
||||
**Building Ethical AI Agents: Balancing Innovation with Responsibility** \\nAs
|
||||
AI agents become more powerful and independent, the imperative to embed ethical
|
||||
frameworks within their design comes sharply into focus. An insightful article
|
||||
on this theme would engage readers by unpacking the complexities of programming
|
||||
morality, fairness, and accountability into AI systems that influence critical
|
||||
decisions\u2014whether in hiring processes, law enforcement, or digital content
|
||||
moderation. Exploring real-world case studies alongside philosophical and
|
||||
regulatory perspectives, the piece could illuminate the delicate balance between
|
||||
technological innovation and societal values, offering a nuanced discussion
|
||||
that appeals to technologists, ethicists, and everyday users alike.\\n\\n3.
|
||||
**AI Agents in Startups: Accelerating Growth and Disrupting Markets** \\nStartups
|
||||
are uniquely positioned to leverage AI agents as game-changers that turbocharge
|
||||
growth, optimize workflows, and unlock new business models. This article could
|
||||
enthrall readers by detailing how nimble companies integrate AI-driven agents
|
||||
for customer engagement, market analysis, and personalized product recommendations\u2014outpacing
|
||||
larger incumbents. It would also examine hurdles such as data privacy, scaling
|
||||
complexities, and the human-AI collaboration dynamic, providing actionable
|
||||
insights for entrepreneurs and investors. The story of AI agents fueling startup
|
||||
innovation not only inspires but also outlines the practical pathways and
|
||||
pitfalls on the frontier of modern entrepreneurship.\\n\\n4. **The Future
|
||||
of Work with AI Agents: Redefining Roles and Skills** \\nAI agents are redefining
|
||||
professional landscapes by automating routine tasks and augmenting human creativity
|
||||
and decision-making. An article on this topic could engage readers by painting
|
||||
a vivid picture of the evolving workplace, where collaboration between humans
|
||||
and AI agents becomes the norm. Delving into emerging roles, necessary skill
|
||||
sets, and how education and training must adapt, the piece would offer a forward-thinking
|
||||
analysis that resonates deeply with employees, managers, and policymakers.
|
||||
Exploring themes of workforce transformation, productivity gains, and potential
|
||||
socioeconomic impacts, it provides a comprehensive outlook on an AI-integrated
|
||||
work environment.\\n\\n5. **From Reactive to Proactive: How Next-Gen AI Agents
|
||||
Anticipate Needs** \\nThe leap from reactive AI assistants to truly proactive
|
||||
AI agents signifies one of the most thrilling advances in artificial intelligence.
|
||||
An article centered on this evolution would captivate readers by illustrating
|
||||
how these agents utilize predictive analytics, contextual understanding, and
|
||||
continuous learning to anticipate user needs before they are expressed. By
|
||||
showcasing pioneering applications in personalized healthcare management,
|
||||
smart homes, and adaptive learning platforms, the article would highlight
|
||||
the profound shift toward intuitive, anticipatory technology. This theme not
|
||||
only excites with futuristic promise but also probes the technical and privacy
|
||||
challenges that come with increased agency and foresight.\",\n \"refusal\":
|
||||
null,\n \"annotations\": []\n },\n \"logprobs\": null,\n
|
||||
\ \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
|
||||
197,\n \"completion_tokens\": 666,\n \"total_tokens\": 863,\n \"prompt_tokens_details\":
|
||||
{\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_ae0f8c9a7b\"\n}\n"
|
||||
\"default\",\n \"system_fingerprint\": \"fp_d45f83c5fd\"\n}\n"
|
||||
headers:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-Ray:
|
||||
- 9db938b0489680d4-EWR
|
||||
- 9ea3cb1cbfe2b312-TPE
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Fri, 13 Mar 2026 07:11:02 GMT
|
||||
- Fri, 10 Apr 2026 18:25:28 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
@@ -428,7 +442,7 @@ interactions:
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '8310'
|
||||
- '9479'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
@@ -467,91 +481,105 @@ interactions:
|
||||
good an article about this topic could be. Return the list of ideas with their
|
||||
paragraph and your notes.\\n\\nThis is the expected criteria for your final
|
||||
answer: 5 bullet points with a paragraph for each idea.\\nyou MUST return the
|
||||
actual complete content as the final answer, not a summary.\"},{\"role\":\"assistant\",\"content\":null,\"tool_calls\":[{\"id\":\"call_G2i9RJGNXKVfnd8ZTaBG8Fwi\",\"type\":\"function\",\"function\":{\"name\":\"ask_question_to_coworker\",\"arguments\":\"{\\\"question\\\":
|
||||
\\\"What are some trending topics or ideas in various fields that could be explored
|
||||
for an article?\\\", \\\"context\\\": \\\"We need to generate a list of 5 interesting
|
||||
ideas to explore for an article. These ideas should be engaging and relevant
|
||||
to current trends or captivating subjects.\\\", \\\"coworker\\\": \\\"Researcher\\\"}\"}},{\"id\":\"call_j4KH2SGZvNeioql0HcRQ9NTp\",\"type\":\"function\",\"function\":{\"name\":\"ask_question_to_coworker\",\"arguments\":\"{\\\"question\\\":
|
||||
\\\"What unique angles or perspectives could we explore to make articles more
|
||||
compelling and engaging?\\\", \\\"context\\\": \\\"Our task involves coming
|
||||
up with 5 ideas for articles, each with an exciting paragraph highlight that
|
||||
illustrates the promise and intrigue of the topic. We want them to be more than
|
||||
generic concepts, shining for readers with fresh insights or engaging twists.\\\",
|
||||
\\\"coworker\\\": \\\"Senior Writer\\\"}\"}}]},{\"role\":\"tool\",\"tool_call_id\":\"call_G2i9RJGNXKVfnd8ZTaBG8Fwi\",\"name\":\"ask_question_to_coworker\",\"content\":\"Here
|
||||
are five trending and engaging topics across various fields that could be explored
|
||||
for an article:\\n\\n1. **The Rise of Autonomous AI Agents and Their Impact
|
||||
on the Future of Work** \\nExplore how autonomous AI agents\u2014systems capable
|
||||
of performing complex tasks independently\u2014are transforming industries such
|
||||
as customer service, software development, and logistics. Discuss implications
|
||||
for job automation, human-AI collaboration, and ethical considerations surrounding
|
||||
decision-making autonomy.\\n\\n2. **Generative AI Beyond Text: Innovations in
|
||||
Audio, Video, and 3D Content Creation** \\nDelve into advancements in generative
|
||||
AI models that create not only text but also realistic audio, video content,
|
||||
virtual environments, and 3D models. Highlight applications in gaming, entertainment,
|
||||
education, and digital marketing, as well as challenges like misinformation
|
||||
and deepfake detection.\\n\\n3. **AI-Driven Climate Modeling: Enhancing Predictive
|
||||
Accuracy to Combat Climate Change** \\nExamine how AI and machine learning
|
||||
are improving climate models by analyzing vast datasets, uncovering patterns,
|
||||
and simulating environmental scenarios. Discuss how these advances are aiding
|
||||
policymakers in making informed decisions to address climate risks and sustainability
|
||||
goals.\\n\\n4. **The Ethical Frontiers of AI in Healthcare: Balancing Innovation
|
||||
with Patient Privacy** \\nInvestigate ethical challenges posed by AI applications
|
||||
in healthcare, including diagnosis, personalized treatment, and patient data
|
||||
management. Focus on balancing rapid technological innovation with privacy,
|
||||
bias mitigation, and regulatory frameworks to ensure equitable access and trust.\\n\\n5.
|
||||
**Quantum Computing Meets AI: Exploring the Next Leap in Computational Power**
|
||||
\ \\nCover the intersection of quantum computing and artificial intelligence,
|
||||
exploring how quantum algorithms could accelerate AI training processes and
|
||||
solve problems beyond the reach of classical computers. Outline current research,
|
||||
potential breakthroughs, and the timeline for real-world applications.\\n\\nEach
|
||||
of these topics is timely, relevant, and has the potential to engage readers
|
||||
interested in cutting-edge technology, societal impact, and future trends. Let
|
||||
me know if you want me to help develop an outline or deeper research into any
|
||||
of these areas!\"},{\"role\":\"tool\",\"tool_call_id\":\"call_j4KH2SGZvNeioql0HcRQ9NTp\",\"name\":\"ask_question_to_coworker\",\"content\":\"Absolutely!
|
||||
To create compelling and engaging AI articles that stand out, we need to go
|
||||
beyond surface-level discussions and deliver fresh perspectives that challenge
|
||||
assumptions and spark curiosity. Here are five unique angles with their highlight
|
||||
paragraphs that could really captivate our readers:\\n\\n1. **The Hidden Psychology
|
||||
of AI Agents: How They Learn Human Biases and What That Means for Our Future**
|
||||
\ \\n*Highlight:* AI agents don\u2019t just process data\u2014they absorb the
|
||||
subtle nuances and biases embedded in human language, behavior, and culture.
|
||||
This article dives deep into the psychological parallels between AI learning
|
||||
mechanisms and human cognitive biases, revealing surprising ways AI can both
|
||||
mirror and amplify our prejudices. Understanding these dynamics is crucial for
|
||||
building trustworthy AI systems and reshaping the future relationship between
|
||||
humans and machines.\\n\\n2. **From Assistants to Autonomous Creators: The Rise
|
||||
of AI Agents as Artists, Writers, and Innovators** \\n*Highlight:* What do
|
||||
we lose and gain when AI agents start producing original art, literature, and
|
||||
innovations? This piece explores groundbreaking examples where AI isn\u2019t
|
||||
just a tool but a creative partner that challenges our definition of authorship
|
||||
and genius. We\u2019ll examine ethical dilemmas, collaborative workflows, and
|
||||
the exciting frontier where human intuition meets algorithmic originality.\\n\\n3.
|
||||
**AI Agents in the Wild: How Decentralized Autonomous Organizations Could Redefine
|
||||
Economy and Governance** \\n*Highlight:* Imagine AI agents operating autonomously
|
||||
in decentralized networks, making real-time decisions that affect finances,
|
||||
resource management, and governance without human intervention. This article
|
||||
uncovers how DAOs powered by AI agents might spontaneously evolve new forms
|
||||
of organization\u2014transparent, efficient, and resistant to traditional corruption.
|
||||
We\u2019ll investigate early case studies and speculate on how this might disrupt
|
||||
centuries-old societal structures.\\n\\n4. **Beyond Chatbots: The Next Generation
|
||||
of AI Agents as Empathetic Digital Companions** \\n*Highlight:* Moving past
|
||||
scripted conversations, emerging AI agents simulate empathy and emotional intelligence
|
||||
in ways that can transform mental health care, education, and companionship.
|
||||
This article provides an insider look at the complex algorithms and biofeedback
|
||||
mechanisms enabling AI to recognize, respond to, and foster human emotions\u2014potentially
|
||||
filling gaps in underserved populations while raising profound questions about
|
||||
authenticity and connection.\\n\\n5. **The Environmental Toll of AI Agents:
|
||||
Unmasking the Ecological Cost of Intelligent Automation** \\n*Highlight:* While
|
||||
AI promises efficiency and innovation, the environmental footprint of training
|
||||
and deploying millions of AI agents is rarely discussed. This eye-opening article
|
||||
quantifies the energy demands of current models, challenges the narrative of
|
||||
AI as an unequivocal green solution, and explores emerging approaches pathing
|
||||
toward sustainable intelligent automation\u2014an urgent conversation for an
|
||||
increasingly eco-conscious tech landscape.\\n\\nEach of these angles opens a
|
||||
door to rich storytelling that blends technical depth, ethical inquiry, and
|
||||
visionary implications\u2014perfect for readers hungry for insight that\u2019s
|
||||
both sophisticated and accessible. Let me know which ones resonate most, or
|
||||
if you want me to refine any into full article outlines!\"},{\"role\":\"user\",\"content\":\"Analyze
|
||||
actual complete content as the final answer, not a summary.\"},{\"role\":\"assistant\",\"content\":null,\"tool_calls\":[{\"id\":\"call_BCh6lXsBTdixRuRh6OTBPoIJ\",\"type\":\"function\",\"function\":{\"name\":\"delegate_work_to_coworker\",\"arguments\":\"{\\\"task\\\":
|
||||
\\\"Come up with a list of 5 interesting ideas to explore for an article.\\\",
|
||||
\\\"context\\\": \\\"We need five intriguing ideas worth exploring for an article.
|
||||
Each idea should have potential for in-depth exploration and appeal to a broad
|
||||
audience, possibly touching on current trends, historical insights, future possibilities,
|
||||
or human interest stories.\\\", \\\"coworker\\\": \\\"Researcher\\\"}\"}},{\"id\":\"call_rAQFeCrS4ogsqvIWRGAYFHGI\",\"type\":\"function\",\"function\":{\"name\":\"delegate_work_to_coworker\",\"arguments\":\"{\\\"task\\\":
|
||||
\\\"Write one amazing paragraph highlight for each of 5 ideas that showcases
|
||||
how good an article about this topic could be.\\\", \\\"context\\\": \\\"Upon
|
||||
receiving five intriguing ideas from the Researcher, create a compelling paragraph
|
||||
for each idea that highlights its potential as a fascinating article. These
|
||||
paragraphs must capture the essence of the topic and explain why it would captivate
|
||||
readers, incorporating possible themes and insights.\\\", \\\"coworker\\\":
|
||||
\\\"Senior Writer\\\"}\"}}]},{\"role\":\"tool\",\"tool_call_id\":\"call_BCh6lXsBTdixRuRh6OTBPoIJ\",\"name\":\"delegate_work_to_coworker\",\"content\":\"Certainly!
|
||||
Here are five intriguing article ideas that offer rich potential for deep exploration
|
||||
and broad audience appeal, especially aligned with current trends and human
|
||||
interest in AI and technology:\\n\\n1. **The Evolution of AI Agents: From Rule-Based
|
||||
Bots to Autonomous Decision Makers** \\n Explore the historical development
|
||||
of AI agents, tracing the journey from simple scripted chatbots to advanced
|
||||
autonomous systems capable of complex decision-making and learning. Dive into
|
||||
key technological milestones, breakthroughs in machine learning, and current
|
||||
state-of-the-art AI agents. Discuss implications for industries such as customer
|
||||
service, healthcare, and autonomous vehicles, highlighting both opportunities
|
||||
and ethical concerns.\\n\\n2. **AI in Daily Life: How Intelligent Agents Are
|
||||
Reshaping Human Routines** \\n Investigate the integration of AI agents in
|
||||
everyday life\u2014from virtual assistants like Siri and Alexa to personalized
|
||||
recommendation systems and smart home devices. Analyze how these AI tools influence
|
||||
productivity, privacy, and social behavior. Include human interest elements
|
||||
through stories of individuals or communities who have embraced or resisted
|
||||
these technologies.\\n\\n3. **The Future of Work: AI Agents as Collaborative
|
||||
Colleagues** \\n Examine how AI agents are transforming workplaces by acting
|
||||
as collaborators rather than just tools. Cover applications in creative fields,
|
||||
data analysis, and decision support, while addressing potential challenges such
|
||||
as job displacement, new skill requirements, and the evolving definition of
|
||||
teamwork. Use expert opinions and case studies to paint a nuanced future outlook.\\n\\n4.
|
||||
**Ethics and Accountability in AI Agent Development** \\n Delve into the
|
||||
ethical dilemmas posed by increasingly autonomous AI agents\u2014topics like
|
||||
bias in algorithms, data privacy, and accountability for AI-driven decisions.
|
||||
Explore measures being taken globally to regulate AI, frameworks for responsible
|
||||
AI development, and the role of public awareness. Include historical context
|
||||
about technology ethics to provide depth.\\n\\n5. **Human-AI Symbiosis: Stories
|
||||
of Innovative Partnerships Shaping Our World** \\n Tell compelling human
|
||||
interest stories about individuals or organizations pioneering collaborative
|
||||
projects with AI agents that lead to breakthroughs in science, art, or social
|
||||
good. Highlight how these partnerships transcend traditional human-machine interaction
|
||||
and open new creative and problem-solving possibilities, inspiring readers about
|
||||
the potential of human-AI synergy.\\n\\nThese ideas are designed to be both
|
||||
engaging and informative, offering multiple angles\u2014technical, historical,
|
||||
ethical, and personal\u2014to keep readers captivated while providing substantial
|
||||
content for in-depth analysis.\"},{\"role\":\"tool\",\"tool_call_id\":\"call_rAQFeCrS4ogsqvIWRGAYFHGI\",\"name\":\"delegate_work_to_coworker\",\"content\":\"1.
|
||||
**The Rise of Autonomous AI Agents: Revolutionizing Everyday Tasks** \\nImagine
|
||||
a world where AI agents autonomously manage your daily schedule, optimize your
|
||||
work routines, and even handle complex decision-making with minimal human intervention.
|
||||
An article exploring the rise of autonomous AI agents would captivate readers
|
||||
by diving into how advancements in machine learning and natural language processing
|
||||
have matured these agents from simple chatbots to intelligent collaborators.
|
||||
Themes could include practical applications in industries like healthcare, finance,
|
||||
and personal productivity, the challenges of trust and transparency, and a glimpse
|
||||
into the ethical questions surrounding AI autonomy. This topic not only showcases
|
||||
cutting-edge technology but also invites readers to envision the near future
|
||||
of human-AI synergy.\\n\\n2. **Building Ethical AI Agents: Balancing Innovation
|
||||
with Responsibility** \\nAs AI agents become more powerful and independent,
|
||||
the imperative to embed ethical frameworks within their design comes sharply
|
||||
into focus. An insightful article on this theme would engage readers by unpacking
|
||||
the complexities of programming morality, fairness, and accountability into
|
||||
AI systems that influence critical decisions\u2014whether in hiring processes,
|
||||
law enforcement, or digital content moderation. Exploring real-world case studies
|
||||
alongside philosophical and regulatory perspectives, the piece could illuminate
|
||||
the delicate balance between technological innovation and societal values, offering
|
||||
a nuanced discussion that appeals to technologists, ethicists, and everyday
|
||||
users alike.\\n\\n3. **AI Agents in Startups: Accelerating Growth and Disrupting
|
||||
Markets** \\nStartups are uniquely positioned to leverage AI agents as game-changers
|
||||
that turbocharge growth, optimize workflows, and unlock new business models.
|
||||
This article could enthrall readers by detailing how nimble companies integrate
|
||||
AI-driven agents for customer engagement, market analysis, and personalized
|
||||
product recommendations\u2014outpacing larger incumbents. It would also examine
|
||||
hurdles such as data privacy, scaling complexities, and the human-AI collaboration
|
||||
dynamic, providing actionable insights for entrepreneurs and investors. The
|
||||
story of AI agents fueling startup innovation not only inspires but also outlines
|
||||
the practical pathways and pitfalls on the frontier of modern entrepreneurship.\\n\\n4.
|
||||
**The Future of Work with AI Agents: Redefining Roles and Skills** \\nAI agents
|
||||
are redefining professional landscapes by automating routine tasks and augmenting
|
||||
human creativity and decision-making. An article on this topic could engage
|
||||
readers by painting a vivid picture of the evolving workplace, where collaboration
|
||||
between humans and AI agents becomes the norm. Delving into emerging roles,
|
||||
necessary skill sets, and how education and training must adapt, the piece would
|
||||
offer a forward-thinking analysis that resonates deeply with employees, managers,
|
||||
and policymakers. Exploring themes of workforce transformation, productivity
|
||||
gains, and potential socioeconomic impacts, it provides a comprehensive outlook
|
||||
on an AI-integrated work environment.\\n\\n5. **From Reactive to Proactive:
|
||||
How Next-Gen AI Agents Anticipate Needs** \\nThe leap from reactive AI assistants
|
||||
to truly proactive AI agents signifies one of the most thrilling advances in
|
||||
artificial intelligence. An article centered on this evolution would captivate
|
||||
readers by illustrating how these agents utilize predictive analytics, contextual
|
||||
understanding, and continuous learning to anticipate user needs before they
|
||||
are expressed. By showcasing pioneering applications in personalized healthcare
|
||||
management, smart homes, and adaptive learning platforms, the article would
|
||||
highlight the profound shift toward intuitive, anticipatory technology. This
|
||||
theme not only excites with futuristic promise but also probes the technical
|
||||
and privacy challenges that come with increased agency and foresight.\"},{\"role\":\"user\",\"content\":\"Analyze
|
||||
the tool result. If requirements are met, provide the Final Answer. Otherwise,
|
||||
call the next tool. Deliver only the answer without meta-commentary.\"}],\"model\":\"gpt-4o\",\"tool_choice\":\"auto\",\"tools\":[{\"type\":\"function\",\"function\":{\"name\":\"delegate_work_to_coworker\",\"description\":\"Delegate
|
||||
a specific task to one of the following coworkers: Researcher, Senior Writer\\nThe
|
||||
@@ -582,7 +610,7 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '9923'
|
||||
- '11056'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
@@ -598,7 +626,7 @@ interactions:
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
- 2.31.0
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
@@ -606,58 +634,64 @@ interactions:
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
- 3.13.12
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-DIqxidsfoqQl7qXSIVHfSCyETUwlU\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1773385862,\n \"model\": \"gpt-4o-2024-08-06\",\n
|
||||
string: "{\n \"id\": \"chatcmpl-DTApljTaq8nDgNMS21B319i56seCn\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1775845529,\n \"model\": \"gpt-4o-2024-08-06\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"1. **The Rise of Autonomous AI Agents
|
||||
and Their Impact on the Future of Work** \\nExplore how autonomous AI agents\u2014systems
|
||||
capable of performing complex tasks independently\u2014are transforming industries
|
||||
such as customer service, software development, and logistics. Discuss implications
|
||||
for job automation, human-AI collaboration, and ethical considerations surrounding
|
||||
decision-making autonomy.\\n\\n2. **Generative AI Beyond Text: Innovations
|
||||
in Audio, Video, and 3D Content Creation** \\nDelve into advancements in
|
||||
generative AI models that create not only text but also realistic audio, video
|
||||
content, virtual environments, and 3D models. Highlight applications in gaming,
|
||||
entertainment, education, and digital marketing, as well as challenges like
|
||||
misinformation and deepfake detection.\\n\\n3. **AI-Driven Climate Modeling:
|
||||
Enhancing Predictive Accuracy to Combat Climate Change** \\nExamine how AI
|
||||
and machine learning are improving climate models by analyzing vast datasets,
|
||||
uncovering patterns, and simulating environmental scenarios. Discuss how these
|
||||
advances are aiding policymakers in making informed decisions to address climate
|
||||
risks and sustainability goals.\\n\\n4. **The Ethical Frontiers of AI in Healthcare:
|
||||
Balancing Innovation with Patient Privacy** \\nInvestigate ethical challenges
|
||||
posed by AI applications in healthcare, including diagnosis, personalized
|
||||
treatment, and patient data management. Focus on balancing rapid technological
|
||||
innovation with privacy, bias mitigation, and regulatory frameworks to ensure
|
||||
equitable access and trust.\\n\\n5. **Quantum Computing Meets AI: Exploring
|
||||
the Next Leap in Computational Power** \\nCover the intersection of quantum
|
||||
computing and artificial intelligence, exploring how quantum algorithms could
|
||||
accelerate AI training processes and solve problems beyond the reach of classical
|
||||
computers. Outline current research, potential breakthroughs, and the timeline
|
||||
for real-world applications.\",\n \"refusal\": null,\n \"annotations\":
|
||||
[]\n },\n \"logprobs\": null,\n \"finish_reason\": \"stop\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 1748,\n \"completion_tokens\":
|
||||
335,\n \"total_tokens\": 2083,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
\"assistant\",\n \"content\": \"- **The Evolution of AI Agents: From
|
||||
Rule-Based Bots to Autonomous Decision Makers** \\n Explore the historical
|
||||
development of AI agents, tracing the journey from simple scripted chatbots
|
||||
to advanced autonomous systems capable of complex decision-making and learning.
|
||||
Dive into key technological milestones, breakthroughs in machine learning,
|
||||
and current state-of-the-art AI agents. Discuss implications for industries
|
||||
such as customer service, healthcare, and autonomous vehicles, highlighting
|
||||
both opportunities and ethical concerns.\\n\\n- **AI in Daily Life: How Intelligent
|
||||
Agents Are Reshaping Human Routines** \\n Investigate the integration of
|
||||
AI agents in everyday life\u2014from virtual assistants like Siri and Alexa
|
||||
to personalized recommendation systems and smart home devices. Analyze how
|
||||
these AI tools influence productivity, privacy, and social behavior. Include
|
||||
human interest elements through stories of individuals or communities who
|
||||
have embraced or resisted these technologies.\\n\\n- **The Future of Work:
|
||||
AI Agents as Collaborative Colleagues** \\n Examine how AI agents are transforming
|
||||
workplaces by acting as collaborators rather than just tools. Cover applications
|
||||
in creative fields, data analysis, and decision support, while addressing
|
||||
potential challenges such as job displacement, new skill requirements, and
|
||||
the evolving definition of teamwork. Use expert opinions and case studies
|
||||
to paint a nuanced future outlook.\\n\\n- **Ethics and Accountability in AI
|
||||
Agent Development** \\n Delve into the ethical dilemmas posed by increasingly
|
||||
autonomous AI agents\u2014topics like bias in algorithms, data privacy, and
|
||||
accountability for AI-driven decisions. Explore measures being taken globally
|
||||
to regulate AI, frameworks for responsible AI development, and the role of
|
||||
public awareness. Include historical context about technology ethics to provide
|
||||
depth.\\n\\n- **Human-AI Symbiosis: Stories of Innovative Partnerships Shaping
|
||||
Our World** \\n Tell compelling human interest stories about individuals
|
||||
or organizations pioneering collaborative projects with AI agents that lead
|
||||
to breakthroughs in science, art, or social good. Highlight how these partnerships
|
||||
transcend traditional human-machine interaction and open new creative and
|
||||
problem-solving possibilities, inspiring readers about the potential of human-AI
|
||||
synergy.\",\n \"refusal\": null,\n \"annotations\": []\n },\n
|
||||
\ \"logprobs\": null,\n \"finish_reason\": \"stop\"\n }\n ],\n
|
||||
\ \"usage\": {\n \"prompt_tokens\": 1903,\n \"completion_tokens\": 399,\n
|
||||
\ \"total_tokens\": 2302,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_b7c8e3f100\"\n}\n"
|
||||
\"default\",\n \"system_fingerprint\": \"fp_df40ab6c25\"\n}\n"
|
||||
headers:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-Ray:
|
||||
- 9db938e60d5bc5e7-EWR
|
||||
- 9ea3cb5a6957b301-TPE
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Fri, 13 Mar 2026 07:11:04 GMT
|
||||
- Fri, 10 Apr 2026 18:25:31 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
@@ -673,7 +707,7 @@ interactions:
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '2009'
|
||||
- '2183'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
|
||||
@@ -125,7 +125,7 @@ class TestDeployCommand(unittest.TestCase):
|
||||
mock_response.json.return_value = {"uuid": "test-uuid"}
|
||||
self.mock_client.deploy_by_uuid.return_value = mock_response
|
||||
|
||||
self.deploy_command.deploy(uuid="test-uuid")
|
||||
self.deploy_command.deploy(uuid="test-uuid", skip_validate=True)
|
||||
|
||||
self.mock_client.deploy_by_uuid.assert_called_once_with("test-uuid")
|
||||
mock_display.assert_called_once_with({"uuid": "test-uuid"})
|
||||
@@ -137,7 +137,7 @@ class TestDeployCommand(unittest.TestCase):
|
||||
mock_response.json.return_value = {"uuid": "test-uuid"}
|
||||
self.mock_client.deploy_by_name.return_value = mock_response
|
||||
|
||||
self.deploy_command.deploy()
|
||||
self.deploy_command.deploy(skip_validate=True)
|
||||
|
||||
self.mock_client.deploy_by_name.assert_called_once_with("test_project")
|
||||
mock_display.assert_called_once_with({"uuid": "test-uuid"})
|
||||
@@ -156,7 +156,7 @@ class TestDeployCommand(unittest.TestCase):
|
||||
self.mock_client.create_crew.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.create_crew()
|
||||
self.deploy_command.create_crew(skip_validate=True)
|
||||
self.assertIn("Deployment created successfully!", fake_out.getvalue())
|
||||
self.assertIn("new-uuid", fake_out.getvalue())
|
||||
|
||||
|
||||
430
lib/crewai/tests/cli/deploy/test_validate.py
Normal file
430
lib/crewai/tests/cli/deploy/test_validate.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""Tests for `crewai.cli.deploy.validate`.
|
||||
|
||||
The fixtures here correspond 1:1 to the deployment-failure patterns observed
|
||||
in the #crewai-deployment-failures Slack channel that motivated this work.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from typing import Iterable
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.cli.deploy.validate import (
|
||||
DeployValidator,
|
||||
Severity,
|
||||
normalize_package_name,
|
||||
)
|
||||
|
||||
|
||||
def _make_pyproject(
|
||||
name: str = "my_crew",
|
||||
dependencies: Iterable[str] = ("crewai>=1.14.0",),
|
||||
*,
|
||||
hatchling: bool = False,
|
||||
flow: bool = False,
|
||||
extra: str = "",
|
||||
) -> str:
|
||||
deps = ", ".join(f'"{d}"' for d in dependencies)
|
||||
lines = [
|
||||
"[project]",
|
||||
f'name = "{name}"',
|
||||
'version = "0.1.0"',
|
||||
f"dependencies = [{deps}]",
|
||||
]
|
||||
if hatchling:
|
||||
lines += [
|
||||
"",
|
||||
"[build-system]",
|
||||
'requires = ["hatchling"]',
|
||||
'build-backend = "hatchling.build"',
|
||||
]
|
||||
if flow:
|
||||
lines += ["", "[tool.crewai]", 'type = "flow"']
|
||||
if extra:
|
||||
lines += ["", extra]
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def _scaffold_standard_crew(
|
||||
root: Path,
|
||||
*,
|
||||
name: str = "my_crew",
|
||||
include_crew_py: bool = True,
|
||||
include_agents_yaml: bool = True,
|
||||
include_tasks_yaml: bool = True,
|
||||
include_lockfile: bool = True,
|
||||
pyproject: str | None = None,
|
||||
) -> Path:
|
||||
(root / "pyproject.toml").write_text(pyproject or _make_pyproject(name=name))
|
||||
if include_lockfile:
|
||||
(root / "uv.lock").write_text("# dummy uv lockfile\n")
|
||||
|
||||
pkg_dir = root / "src" / normalize_package_name(name)
|
||||
pkg_dir.mkdir(parents=True)
|
||||
(pkg_dir / "__init__.py").write_text("")
|
||||
|
||||
if include_crew_py:
|
||||
(pkg_dir / "crew.py").write_text(
|
||||
dedent(
|
||||
"""
|
||||
from crewai.project import CrewBase, crew
|
||||
|
||||
@CrewBase
|
||||
class MyCrew:
|
||||
agents_config = "config/agents.yaml"
|
||||
tasks_config = "config/tasks.yaml"
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
from crewai import Crew
|
||||
return Crew(agents=[], tasks=[])
|
||||
"""
|
||||
).strip()
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
config_dir = pkg_dir / "config"
|
||||
config_dir.mkdir()
|
||||
if include_agents_yaml:
|
||||
(config_dir / "agents.yaml").write_text("{}\n")
|
||||
if include_tasks_yaml:
|
||||
(config_dir / "tasks.yaml").write_text("{}\n")
|
||||
|
||||
return pkg_dir
|
||||
|
||||
|
||||
def _codes(validator: DeployValidator) -> set[str]:
|
||||
return {r.code for r in validator.results}
|
||||
|
||||
|
||||
def _run_without_import_check(root: Path) -> DeployValidator:
|
||||
"""Run validation with the subprocess-based import check stubbed out;
|
||||
the classifier is exercised directly in its own tests below."""
|
||||
with patch.object(DeployValidator, "_check_module_imports", lambda self: None):
|
||||
v = DeployValidator(project_root=root)
|
||||
v.run()
|
||||
return v
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"project_name, expected",
|
||||
[
|
||||
("my-crew", "my_crew"),
|
||||
("My Cool-Project", "my_cool_project"),
|
||||
("crew123", "crew123"),
|
||||
("crew.name!with$chars", "crewnamewithchars"),
|
||||
],
|
||||
)
|
||||
def test_normalize_package_name(project_name: str, expected: str) -> None:
|
||||
assert normalize_package_name(project_name) == expected
|
||||
|
||||
|
||||
def test_valid_standard_crew_project_passes(tmp_path: Path) -> None:
|
||||
_scaffold_standard_crew(tmp_path)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert v.ok, f"expected clean run, got {v.results}"
|
||||
|
||||
|
||||
def test_missing_pyproject_errors(tmp_path: Path) -> None:
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "missing_pyproject" in _codes(v)
|
||||
assert not v.ok
|
||||
|
||||
|
||||
def test_invalid_pyproject_errors(tmp_path: Path) -> None:
|
||||
(tmp_path / "pyproject.toml").write_text("this is not valid toml ====\n")
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "invalid_pyproject" in _codes(v)
|
||||
|
||||
|
||||
def test_missing_project_name_errors(tmp_path: Path) -> None:
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[project]\nversion = "0.1.0"\ndependencies = ["crewai>=1.14.0"]\n'
|
||||
)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "missing_project_name" in _codes(v)
|
||||
|
||||
|
||||
def test_missing_lockfile_errors(tmp_path: Path) -> None:
|
||||
_scaffold_standard_crew(tmp_path, include_lockfile=False)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "missing_lockfile" in _codes(v)
|
||||
|
||||
|
||||
def test_poetry_lock_is_accepted(tmp_path: Path) -> None:
|
||||
_scaffold_standard_crew(tmp_path, include_lockfile=False)
|
||||
(tmp_path / "poetry.lock").write_text("# poetry lockfile\n")
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "missing_lockfile" not in _codes(v)
|
||||
|
||||
|
||||
def test_stale_lockfile_warns(tmp_path: Path) -> None:
|
||||
_scaffold_standard_crew(tmp_path)
|
||||
# Make lockfile older than pyproject.
|
||||
lock = tmp_path / "uv.lock"
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
old_time = pyproject.stat().st_mtime - 60
|
||||
import os
|
||||
|
||||
os.utime(lock, (old_time, old_time))
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "stale_lockfile" in _codes(v)
|
||||
# Stale is a warning, so the run can still be ok (no errors).
|
||||
assert v.ok
|
||||
|
||||
|
||||
def test_missing_package_dir_errors(tmp_path: Path) -> None:
|
||||
# pyproject says name=my_crew but we only create src/other_pkg/
|
||||
(tmp_path / "pyproject.toml").write_text(_make_pyproject(name="my_crew"))
|
||||
(tmp_path / "uv.lock").write_text("")
|
||||
(tmp_path / "src" / "other_pkg").mkdir(parents=True)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
codes = _codes(v)
|
||||
assert "missing_package_dir" in codes
|
||||
finding = next(r for r in v.results if r.code == "missing_package_dir")
|
||||
assert "other_pkg" in finding.hint
|
||||
|
||||
|
||||
def test_egg_info_only_errors_with_targeted_hint(tmp_path: Path) -> None:
|
||||
"""Regression for the case where only src/<name>.egg-info/ exists."""
|
||||
(tmp_path / "pyproject.toml").write_text(_make_pyproject(name="odoo_pm_agents"))
|
||||
(tmp_path / "uv.lock").write_text("")
|
||||
(tmp_path / "src" / "odoo_pm_agents.egg-info").mkdir(parents=True)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
finding = next(r for r in v.results if r.code == "missing_package_dir")
|
||||
assert "egg-info" in finding.hint
|
||||
|
||||
|
||||
def test_stale_egg_info_sibling_warns(tmp_path: Path) -> None:
|
||||
_scaffold_standard_crew(tmp_path)
|
||||
(tmp_path / "src" / "my_crew.egg-info").mkdir()
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "stale_egg_info" in _codes(v)
|
||||
|
||||
|
||||
def test_missing_crew_py_errors(tmp_path: Path) -> None:
|
||||
_scaffold_standard_crew(tmp_path, include_crew_py=False)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "missing_crew_py" in _codes(v)
|
||||
|
||||
|
||||
def test_missing_agents_yaml_errors(tmp_path: Path) -> None:
|
||||
_scaffold_standard_crew(tmp_path, include_agents_yaml=False)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "missing_agents_yaml" in _codes(v)
|
||||
|
||||
|
||||
def test_missing_tasks_yaml_errors(tmp_path: Path) -> None:
|
||||
_scaffold_standard_crew(tmp_path, include_tasks_yaml=False)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "missing_tasks_yaml" in _codes(v)
|
||||
|
||||
|
||||
def test_flow_project_requires_main_py(tmp_path: Path) -> None:
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
_make_pyproject(name="my_flow", flow=True)
|
||||
)
|
||||
(tmp_path / "uv.lock").write_text("")
|
||||
(tmp_path / "src" / "my_flow").mkdir(parents=True)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "missing_flow_main" in _codes(v)
|
||||
|
||||
|
||||
def test_flow_project_with_main_py_passes(tmp_path: Path) -> None:
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
_make_pyproject(name="my_flow", flow=True)
|
||||
)
|
||||
(tmp_path / "uv.lock").write_text("")
|
||||
pkg = tmp_path / "src" / "my_flow"
|
||||
pkg.mkdir(parents=True)
|
||||
(pkg / "main.py").write_text("# flow entrypoint\n")
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "missing_flow_main" not in _codes(v)
|
||||
|
||||
|
||||
def test_hatchling_without_wheel_config_passes_when_pkg_dir_matches(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
_scaffold_standard_crew(
|
||||
tmp_path, pyproject=_make_pyproject(name="my_crew", hatchling=True)
|
||||
)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
# src/my_crew/ exists, so hatch default should find it — no wheel error.
|
||||
assert "hatch_wheel_target_missing" not in _codes(v)
|
||||
|
||||
|
||||
def test_hatchling_with_explicit_wheel_config_passes(tmp_path: Path) -> None:
|
||||
extra = (
|
||||
"[tool.hatch.build.targets.wheel]\n"
|
||||
'packages = ["src/my_crew"]'
|
||||
)
|
||||
_scaffold_standard_crew(
|
||||
tmp_path,
|
||||
pyproject=_make_pyproject(name="my_crew", hatchling=True, extra=extra),
|
||||
)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "hatch_wheel_target_missing" not in _codes(v)
|
||||
|
||||
|
||||
def test_classify_missing_openai_key_is_warning(tmp_path: Path) -> None:
|
||||
v = DeployValidator(project_root=tmp_path)
|
||||
v._classify_import_error(
|
||||
"ImportError",
|
||||
"Error importing native provider: 1 validation error for OpenAICompletion\n"
|
||||
" Value error, OPENAI_API_KEY is required",
|
||||
tb="",
|
||||
)
|
||||
assert len(v.results) == 1
|
||||
result = v.results[0]
|
||||
assert result.code == "llm_init_missing_key"
|
||||
assert result.severity is Severity.WARNING
|
||||
assert "OPENAI_API_KEY" in result.title
|
||||
|
||||
|
||||
def test_classify_azure_extra_missing_is_error(tmp_path: Path) -> None:
|
||||
"""The real message raised by the Azure provider module uses plain
|
||||
double quotes around the install command (no backticks). Match the
|
||||
exact string that ships in the provider source so this test actually
|
||||
guards the regex used in production."""
|
||||
v = DeployValidator(project_root=tmp_path)
|
||||
v._classify_import_error(
|
||||
"ImportError",
|
||||
'Azure AI Inference native provider not available, to install: uv add "crewai[azure-ai-inference]"',
|
||||
tb="",
|
||||
)
|
||||
assert "missing_provider_extra" in _codes(v)
|
||||
finding = next(r for r in v.results if r.code == "missing_provider_extra")
|
||||
assert finding.title.startswith("Azure AI Inference")
|
||||
assert 'uv add "crewai[azure-ai-inference]"' in finding.hint
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pkg_label, install_cmd",
|
||||
[
|
||||
("Anthropic", 'uv add "crewai[anthropic]"'),
|
||||
("AWS Bedrock", 'uv add "crewai[bedrock]"'),
|
||||
("Google Gen AI", 'uv add "crewai[google-genai]"'),
|
||||
],
|
||||
)
|
||||
def test_classify_missing_provider_extra_matches_real_messages(
|
||||
tmp_path: Path, pkg_label: str, install_cmd: str
|
||||
) -> None:
|
||||
"""Regression for the four provider error strings verbatim."""
|
||||
v = DeployValidator(project_root=tmp_path)
|
||||
v._classify_import_error(
|
||||
"ImportError",
|
||||
f"{pkg_label} native provider not available, to install: {install_cmd}",
|
||||
tb="",
|
||||
)
|
||||
assert "missing_provider_extra" in _codes(v)
|
||||
finding = next(r for r in v.results if r.code == "missing_provider_extra")
|
||||
assert install_cmd in finding.hint
|
||||
|
||||
|
||||
def test_classify_keyerror_at_import_is_warning(tmp_path: Path) -> None:
|
||||
"""Regression for `KeyError: 'SERPLY_API_KEY'` raised at import time."""
|
||||
v = DeployValidator(project_root=tmp_path)
|
||||
v._classify_import_error("KeyError", "'SERPLY_API_KEY'", tb="")
|
||||
codes = _codes(v)
|
||||
assert "env_var_read_at_import" in codes
|
||||
|
||||
|
||||
def test_classify_no_crewbase_class_is_error(tmp_path: Path) -> None:
|
||||
v = DeployValidator(project_root=tmp_path)
|
||||
v._classify_import_error(
|
||||
"ValueError",
|
||||
"Crew class annotated with @CrewBase not found.",
|
||||
tb="",
|
||||
)
|
||||
assert "no_crewbase_class" in _codes(v)
|
||||
|
||||
|
||||
def test_classify_no_flow_subclass_is_error(tmp_path: Path) -> None:
|
||||
v = DeployValidator(project_root=tmp_path)
|
||||
v._classify_import_error("ValueError", "No Flow subclass found in the module.", tb="")
|
||||
assert "no_flow_subclass" in _codes(v)
|
||||
|
||||
|
||||
def test_classify_stale_crewai_pin_attribute_error(tmp_path: Path) -> None:
|
||||
"""Regression for a stale crewai pin missing `_load_response_format`."""
|
||||
v = DeployValidator(project_root=tmp_path)
|
||||
v._classify_import_error(
|
||||
"AttributeError",
|
||||
"'EmploymentServiceDecisionSupportSystemCrew' object has no attribute '_load_response_format'",
|
||||
tb="",
|
||||
)
|
||||
assert "stale_crewai_pin" in _codes(v)
|
||||
|
||||
|
||||
def test_classify_unknown_error_is_fallback(tmp_path: Path) -> None:
|
||||
v = DeployValidator(project_root=tmp_path)
|
||||
v._classify_import_error("RuntimeError", "something weird happened", tb="")
|
||||
assert "import_failed" in _codes(v)
|
||||
|
||||
|
||||
def test_env_var_referenced_but_missing_warns(tmp_path: Path) -> None:
|
||||
pkg = _scaffold_standard_crew(tmp_path)
|
||||
(pkg / "tools.py").write_text(
|
||||
'import os\nkey = os.getenv("TAVILY_API_KEY")\n'
|
||||
)
|
||||
import os
|
||||
|
||||
# Make sure the test doesn't inherit the key from the host environment.
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("TAVILY_API_KEY", None)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
codes = _codes(v)
|
||||
assert "env_vars_not_in_dotenv" in codes
|
||||
|
||||
|
||||
def test_env_var_in_dotenv_does_not_warn(tmp_path: Path) -> None:
|
||||
pkg = _scaffold_standard_crew(tmp_path)
|
||||
(pkg / "tools.py").write_text(
|
||||
'import os\nkey = os.getenv("TAVILY_API_KEY")\n'
|
||||
)
|
||||
(tmp_path / ".env").write_text("TAVILY_API_KEY=abc\n")
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "env_vars_not_in_dotenv" not in _codes(v)
|
||||
|
||||
|
||||
def test_old_crewai_pin_in_uv_lock_warns(tmp_path: Path) -> None:
|
||||
_scaffold_standard_crew(tmp_path)
|
||||
(tmp_path / "uv.lock").write_text(
|
||||
'name = "crewai"\nversion = "1.10.0"\nsource = { registry = "..." }\n'
|
||||
)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "old_crewai_pin" in _codes(v)
|
||||
|
||||
|
||||
def test_modern_crewai_pin_does_not_warn(tmp_path: Path) -> None:
|
||||
_scaffold_standard_crew(tmp_path)
|
||||
(tmp_path / "uv.lock").write_text(
|
||||
'name = "crewai"\nversion = "1.14.1"\nsource = { registry = "..." }\n'
|
||||
)
|
||||
v = _run_without_import_check(tmp_path)
|
||||
assert "old_crewai_pin" not in _codes(v)
|
||||
|
||||
|
||||
def test_create_crew_aborts_on_validation_error(tmp_path: Path) -> None:
|
||||
"""`crewai deploy create` must not contact the API when validation fails."""
|
||||
from unittest.mock import MagicMock, patch as mock_patch
|
||||
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
|
||||
with (
|
||||
mock_patch("crewai.cli.command.get_auth_token", return_value="tok"),
|
||||
mock_patch("crewai.cli.deploy.main.get_project_name", return_value="p"),
|
||||
mock_patch("crewai.cli.command.PlusAPI") as mock_api,
|
||||
mock_patch(
|
||||
"crewai.cli.deploy.main.validate_project"
|
||||
) as mock_validate,
|
||||
):
|
||||
mock_validate.return_value = MagicMock(ok=False)
|
||||
cmd = DeployCommand()
|
||||
cmd.create_crew()
|
||||
assert not cmd.plus_api_client.create_crew.called
|
||||
del mock_api # silence unused-var lint
|
||||
0
lib/crewai/tests/cli/remote_template/__init__.py
Normal file
0
lib/crewai/tests/cli/remote_template/__init__.py
Normal file
283
lib/crewai/tests/cli/remote_template/test_main.py
Normal file
283
lib/crewai/tests/cli/remote_template/test_main.py
Normal file
@@ -0,0 +1,283 @@
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from crewai.cli.cli import template_add, template_list
|
||||
from crewai.cli.remote_template.main import TemplateCommand
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
SAMPLE_REPOS = [
|
||||
{"name": "template_deep_research", "description": "Deep research template", "private": False},
|
||||
{"name": "template_pull_request_review", "description": "PR review template", "private": False},
|
||||
{"name": "template_conversational_example", "description": "Conversational demo", "private": False},
|
||||
{"name": "crewai", "description": "Main repo", "private": False},
|
||||
{"name": "marketplace-crew-template", "description": "Marketplace", "private": False},
|
||||
]
|
||||
|
||||
|
||||
def _make_zipball(files: dict[str, str], top_dir: str = "crewAIInc-template_test-abc123") -> bytes:
|
||||
"""Create an in-memory zipball mimicking GitHub's format."""
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as zf:
|
||||
zf.writestr(f"{top_dir}/", "")
|
||||
for path, content in files.items():
|
||||
zf.writestr(f"{top_dir}/{path}", content)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
# --- CLI command tests ---
|
||||
|
||||
|
||||
@patch("crewai.cli.cli.TemplateCommand")
|
||||
def test_template_list_command(mock_cls, runner):
|
||||
mock_instance = MagicMock()
|
||||
mock_cls.return_value = mock_instance
|
||||
|
||||
result = runner.invoke(template_list)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_cls.assert_called_once()
|
||||
mock_instance.list_templates.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai.cli.cli.TemplateCommand")
|
||||
def test_template_add_command(mock_cls, runner):
|
||||
mock_instance = MagicMock()
|
||||
mock_cls.return_value = mock_instance
|
||||
|
||||
result = runner.invoke(template_add, ["deep_research"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_cls.assert_called_once()
|
||||
mock_instance.add_template.assert_called_once_with("deep_research", None)
|
||||
|
||||
|
||||
@patch("crewai.cli.cli.TemplateCommand")
|
||||
def test_template_add_with_output_dir(mock_cls, runner):
|
||||
mock_instance = MagicMock()
|
||||
mock_cls.return_value = mock_instance
|
||||
|
||||
result = runner.invoke(template_add, ["deep_research", "-o", "my_project"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_instance.add_template.assert_called_once_with("deep_research", "my_project")
|
||||
|
||||
|
||||
# --- TemplateCommand unit tests ---
|
||||
|
||||
|
||||
class TestTemplateCommand:
|
||||
@pytest.fixture
|
||||
def cmd(self):
|
||||
with patch.object(TemplateCommand, "__init__", return_value=None):
|
||||
instance = TemplateCommand()
|
||||
instance._telemetry = MagicMock()
|
||||
return instance
|
||||
|
||||
@patch("crewai.cli.remote_template.main.httpx.get")
|
||||
def test_fetch_templates_filters_by_prefix(self, mock_get, cmd):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = SAMPLE_REPOS
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
# Return empty on page 2 to stop pagination
|
||||
mock_empty = MagicMock()
|
||||
mock_empty.json.return_value = []
|
||||
mock_empty.raise_for_status = MagicMock()
|
||||
mock_get.side_effect = [mock_response, mock_empty]
|
||||
|
||||
templates = cmd._fetch_templates()
|
||||
|
||||
assert len(templates) == 3
|
||||
assert all(t["name"].startswith("template_") for t in templates)
|
||||
|
||||
@patch("crewai.cli.remote_template.main.httpx.get")
|
||||
def test_fetch_templates_excludes_private(self, mock_get, cmd):
|
||||
repos = [
|
||||
{"name": "template_private_one", "description": "", "private": True},
|
||||
{"name": "template_public_one", "description": "", "private": False},
|
||||
]
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = repos
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_empty = MagicMock()
|
||||
mock_empty.json.return_value = []
|
||||
mock_empty.raise_for_status = MagicMock()
|
||||
mock_get.side_effect = [mock_response, mock_empty]
|
||||
|
||||
templates = cmd._fetch_templates()
|
||||
|
||||
assert len(templates) == 1
|
||||
assert templates[0]["name"] == "template_public_one"
|
||||
|
||||
@patch("crewai.cli.remote_template.main.httpx.get")
|
||||
def test_fetch_templates_api_error(self, mock_get, cmd):
|
||||
mock_get.side_effect = httpx.HTTPError("connection error")
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
cmd._fetch_templates()
|
||||
|
||||
@patch("crewai.cli.remote_template.main.click.prompt", return_value="q")
|
||||
@patch("crewai.cli.remote_template.main.httpx.get")
|
||||
def test_list_templates_prints_output(self, mock_get, mock_prompt, cmd):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = SAMPLE_REPOS
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_empty = MagicMock()
|
||||
mock_empty.json.return_value = []
|
||||
mock_empty.raise_for_status = MagicMock()
|
||||
mock_get.side_effect = [mock_response, mock_empty]
|
||||
|
||||
with patch("crewai.cli.remote_template.main.console") as mock_console:
|
||||
cmd.list_templates()
|
||||
assert mock_console.print.call_count > 0
|
||||
|
||||
@patch("crewai.cli.remote_template.main.httpx.get")
|
||||
def test_resolve_repo_name_with_prefix(self, mock_get, cmd):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = SAMPLE_REPOS
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_empty = MagicMock()
|
||||
mock_empty.json.return_value = []
|
||||
mock_empty.raise_for_status = MagicMock()
|
||||
mock_get.side_effect = [mock_response, mock_empty]
|
||||
|
||||
result = cmd._resolve_repo_name("template_deep_research")
|
||||
assert result == "template_deep_research"
|
||||
|
||||
@patch("crewai.cli.remote_template.main.httpx.get")
|
||||
def test_resolve_repo_name_without_prefix(self, mock_get, cmd):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = SAMPLE_REPOS
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_empty = MagicMock()
|
||||
mock_empty.json.return_value = []
|
||||
mock_empty.raise_for_status = MagicMock()
|
||||
mock_get.side_effect = [mock_response, mock_empty]
|
||||
|
||||
result = cmd._resolve_repo_name("deep_research")
|
||||
assert result == "template_deep_research"
|
||||
|
||||
@patch("crewai.cli.remote_template.main.httpx.get")
|
||||
def test_resolve_repo_name_not_found(self, mock_get, cmd):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = SAMPLE_REPOS
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_empty = MagicMock()
|
||||
mock_empty.json.return_value = []
|
||||
mock_empty.raise_for_status = MagicMock()
|
||||
mock_get.side_effect = [mock_response, mock_empty]
|
||||
|
||||
result = cmd._resolve_repo_name("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_extract_zip(self, cmd, tmp_path):
|
||||
files = {
|
||||
"README.md": "# Test Template",
|
||||
"src/main.py": "print('hello')",
|
||||
"config/settings.yaml": "key: value",
|
||||
}
|
||||
zip_bytes = _make_zipball(files)
|
||||
dest = str(tmp_path / "output")
|
||||
|
||||
cmd._extract_zip(zip_bytes, dest)
|
||||
|
||||
assert os.path.isfile(os.path.join(dest, "README.md"))
|
||||
assert os.path.isfile(os.path.join(dest, "src", "main.py"))
|
||||
assert os.path.isfile(os.path.join(dest, "config", "settings.yaml"))
|
||||
|
||||
with open(os.path.join(dest, "src", "main.py")) as f:
|
||||
assert f.read() == "print('hello')"
|
||||
|
||||
@patch.object(TemplateCommand, "_extract_zip")
|
||||
@patch.object(TemplateCommand, "_download_zip")
|
||||
@patch.object(TemplateCommand, "_resolve_repo_name")
|
||||
def test_add_template_success(self, mock_resolve, mock_download, mock_extract, cmd, tmp_path):
|
||||
mock_resolve.return_value = "template_deep_research"
|
||||
mock_download.return_value = b"fake-zip-bytes"
|
||||
|
||||
os.chdir(tmp_path)
|
||||
cmd.add_template("deep_research")
|
||||
|
||||
mock_resolve.assert_called_once_with("deep_research")
|
||||
mock_download.assert_called_once_with("template_deep_research")
|
||||
expected_dest = os.path.join(str(tmp_path), "deep_research")
|
||||
mock_extract.assert_called_once_with(b"fake-zip-bytes", expected_dest)
|
||||
|
||||
@patch.object(TemplateCommand, "_resolve_repo_name")
|
||||
def test_add_template_not_found(self, mock_resolve, cmd):
|
||||
mock_resolve.return_value = None
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
cmd.add_template("nonexistent")
|
||||
|
||||
@patch.object(TemplateCommand, "_extract_zip")
|
||||
@patch.object(TemplateCommand, "_download_zip")
|
||||
@patch("crewai.cli.remote_template.main.click.prompt", return_value="my_project")
|
||||
@patch.object(TemplateCommand, "_resolve_repo_name")
|
||||
def test_add_template_dir_exists_prompts_rename(self, mock_resolve, mock_prompt, mock_download, mock_extract, cmd, tmp_path):
|
||||
mock_resolve.return_value = "template_deep_research"
|
||||
mock_download.return_value = b"fake-zip-bytes"
|
||||
existing = tmp_path / "deep_research"
|
||||
existing.mkdir()
|
||||
|
||||
os.chdir(tmp_path)
|
||||
cmd.add_template("deep_research")
|
||||
|
||||
expected_dest = os.path.join(str(tmp_path), "my_project")
|
||||
mock_extract.assert_called_once_with(b"fake-zip-bytes", expected_dest)
|
||||
|
||||
@patch.object(TemplateCommand, "_resolve_repo_name")
|
||||
@patch("crewai.cli.remote_template.main.click.prompt", return_value="q")
|
||||
def test_add_template_dir_exists_quit(self, mock_prompt, mock_resolve, cmd, tmp_path):
|
||||
mock_resolve.return_value = "template_deep_research"
|
||||
existing = tmp_path / "deep_research"
|
||||
existing.mkdir()
|
||||
|
||||
os.chdir(tmp_path)
|
||||
cmd.add_template("deep_research")
|
||||
# Should return without downloading
|
||||
|
||||
@patch.object(TemplateCommand, "_install_repo")
|
||||
@patch("crewai.cli.remote_template.main.click.prompt", return_value="2")
|
||||
@patch("crewai.cli.remote_template.main.httpx.get")
|
||||
def test_list_templates_selects_and_installs(self, mock_get, mock_prompt, mock_install, cmd):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = SAMPLE_REPOS
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_empty = MagicMock()
|
||||
mock_empty.json.return_value = []
|
||||
mock_empty.raise_for_status = MagicMock()
|
||||
mock_get.side_effect = [mock_response, mock_empty]
|
||||
|
||||
with patch("crewai.cli.remote_template.main.console"):
|
||||
cmd.list_templates()
|
||||
|
||||
# Templates are sorted by name; index 1 (choice "2") = template_deep_research
|
||||
mock_install.assert_called_once_with("template_deep_research")
|
||||
|
||||
@patch.object(TemplateCommand, "_install_repo")
|
||||
@patch("crewai.cli.remote_template.main.click.prompt", return_value="q")
|
||||
@patch("crewai.cli.remote_template.main.httpx.get")
|
||||
def test_list_templates_quit(self, mock_get, mock_prompt, mock_install, cmd):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = SAMPLE_REPOS
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_empty = MagicMock()
|
||||
mock_empty.json.return_value = []
|
||||
mock_empty.raise_for_status = MagicMock()
|
||||
mock_get.side_effect = [mock_response, mock_empty]
|
||||
|
||||
with patch("crewai.cli.remote_template.main.console"):
|
||||
cmd.list_templates()
|
||||
|
||||
mock_install.assert_not_called()
|
||||
@@ -367,7 +367,7 @@ def test_deploy_push(command, runner):
|
||||
result = runner.invoke(deploy_push, ["-u", uuid])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.deploy.assert_called_once_with(uuid=uuid)
|
||||
mock_deploy.deploy.assert_called_once_with(uuid=uuid, skip_validate=False)
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.cli.DeployCommand")
|
||||
@@ -376,7 +376,7 @@ def test_deploy_push_no_uuid(command, runner):
|
||||
result = runner.invoke(deploy_push)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.deploy.assert_called_once_with(uuid=None)
|
||||
mock_deploy.deploy.assert_called_once_with(uuid=None, skip_validate=False)
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.cli.DeployCommand")
|
||||
|
||||
@@ -161,7 +161,8 @@ def test_install_api_error(mock_get, capsys, tool_command):
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
|
||||
def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command):
|
||||
@patch("crewai.cli.tools.main.git.Repository.__init__", return_value=None)
|
||||
def test_publish_when_not_in_sync(mock_init, mock_is_synced, capsys, tool_command):
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
|
||||
@@ -174,3 +174,51 @@ class TestEmitCallCompletedEventPassesUsage:
|
||||
event = mock_emit.call_args[1]["event"]
|
||||
assert isinstance(event, LLMCallCompletedEvent)
|
||||
assert event.usage is None
|
||||
|
||||
class TestUsageMetricsNewFields:
|
||||
def test_add_usage_metrics_aggregates_reasoning_and_cache_creation(self):
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
|
||||
metrics1 = UsageMetrics(
|
||||
total_tokens=100,
|
||||
prompt_tokens=60,
|
||||
completion_tokens=40,
|
||||
cached_prompt_tokens=10,
|
||||
reasoning_tokens=15,
|
||||
cache_creation_tokens=5,
|
||||
successful_requests=1,
|
||||
)
|
||||
metrics2 = UsageMetrics(
|
||||
total_tokens=200,
|
||||
prompt_tokens=120,
|
||||
completion_tokens=80,
|
||||
cached_prompt_tokens=20,
|
||||
reasoning_tokens=25,
|
||||
cache_creation_tokens=10,
|
||||
successful_requests=1,
|
||||
)
|
||||
|
||||
metrics1.add_usage_metrics(metrics2)
|
||||
|
||||
assert metrics1.total_tokens == 300
|
||||
assert metrics1.prompt_tokens == 180
|
||||
assert metrics1.completion_tokens == 120
|
||||
assert metrics1.cached_prompt_tokens == 30
|
||||
assert metrics1.reasoning_tokens == 40
|
||||
assert metrics1.cache_creation_tokens == 15
|
||||
assert metrics1.successful_requests == 2
|
||||
|
||||
def test_new_fields_default_to_zero(self):
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
|
||||
metrics = UsageMetrics()
|
||||
assert metrics.reasoning_tokens == 0
|
||||
assert metrics.cache_creation_tokens == 0
|
||||
|
||||
def test_model_dump_includes_new_fields(self):
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
|
||||
metrics = UsageMetrics(reasoning_tokens=10, cache_creation_tokens=5)
|
||||
dumped = metrics.model_dump()
|
||||
assert dumped["reasoning_tokens"] == 10
|
||||
assert dumped["cache_creation_tokens"] == 5
|
||||
|
||||
@@ -192,6 +192,38 @@ class TestToolHookDecorators:
|
||||
# Should still be 1 (hook didn't execute for read_file)
|
||||
assert len(execution_log) == 1
|
||||
|
||||
def test_before_tool_call_tool_filter_sanitizes_names(self):
|
||||
"""Tool filter should auto-sanitize names so users can pass BaseTool.name directly."""
|
||||
execution_log = []
|
||||
|
||||
# User passes the human-readable tool name (e.g. BaseTool.name)
|
||||
@before_tool_call(tools=["Delete File", "Execute Code"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(context.tool_name)
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
mock_tool = Mock()
|
||||
# Context uses the sanitized name (as set by the executor)
|
||||
context = ToolCallHookContext(
|
||||
tool_name="delete_file",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
hooks[0](context)
|
||||
assert execution_log == ["delete_file"]
|
||||
|
||||
# Non-matching tool still filtered out
|
||||
context2 = ToolCallHookContext(
|
||||
tool_name="read_file",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
hooks[0](context2)
|
||||
assert execution_log == ["delete_file"]
|
||||
|
||||
def test_before_tool_call_with_combined_filters(self):
|
||||
"""Test that combined tool and agent filters work."""
|
||||
execution_log = []
|
||||
|
||||
@@ -1463,3 +1463,45 @@ def test_tool_search_saves_input_tokens():
|
||||
f"Expected tool_search ({usage_search.prompt_tokens}) to use fewer input tokens "
|
||||
f"than no search ({usage_no_search.prompt_tokens})"
|
||||
)
|
||||
|
||||
|
||||
def test_anthropic_cache_creation_tokens_extraction():
|
||||
"""Test that cache_creation_input_tokens are extracted from Anthropic responses."""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text="test response")]
|
||||
mock_response.usage = MagicMock(
|
||||
input_tokens=100,
|
||||
output_tokens=50,
|
||||
cache_read_input_tokens=30,
|
||||
cache_creation_input_tokens=20,
|
||||
)
|
||||
mock_response.stop_reason = None
|
||||
mock_response.model = None
|
||||
|
||||
usage = llm._extract_anthropic_token_usage(mock_response)
|
||||
assert usage["input_tokens"] == 100
|
||||
assert usage["output_tokens"] == 50
|
||||
assert usage["total_tokens"] == 150
|
||||
assert usage["cached_prompt_tokens"] == 30
|
||||
assert usage["cache_creation_tokens"] == 20
|
||||
|
||||
|
||||
def test_anthropic_missing_cache_fields_default_to_zero():
|
||||
"""Test that missing cache fields default to zero."""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text="test response")]
|
||||
mock_response.usage = MagicMock(
|
||||
input_tokens=40,
|
||||
output_tokens=20,
|
||||
spec=["input_tokens", "output_tokens"],
|
||||
)
|
||||
mock_response.usage.cache_read_input_tokens = None
|
||||
mock_response.usage.cache_creation_input_tokens = None
|
||||
|
||||
usage = llm._extract_anthropic_token_usage(mock_response)
|
||||
assert usage["cached_prompt_tokens"] == 0
|
||||
assert usage["cache_creation_tokens"] == 0
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user