mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-15 13:18:09 +00:00
Compare commits
9 Commits
1.14.7
...
flow-itera
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
02eeefe5ea | ||
|
|
d80719df81 | ||
|
|
6ad821b157 | ||
|
|
2444895ca4 | ||
|
|
bf291a7a55 | ||
|
|
64438cba37 | ||
|
|
887adafd2c | ||
|
|
d3fc0d31f8 | ||
|
|
373dca3d04 |
@@ -226,6 +226,48 @@ counter=2 message='Hello from first_method - updated by second_method'
|
||||
من خلال ضمان إعادة مخرجات الدالة الأخيرة وتوفير الوصول إلى الحالة، تجعل تدفقات CrewAI من السهل دمج نتائج سير عمل الذكاء الاصطناعي في التطبيقات أو الأنظمة الأكبر،
|
||||
مع الحفاظ على الوصول إلى الحالة طوال تنفيذ التدفق.
|
||||
|
||||
## مقاييس استخدام التدفق
|
||||
|
||||
بعد اكتمال تنفيذ التدفق، يمكنك الوصول إلى الخاصية `usage_metrics` لعرض إجمالي استخدام التوكنات عبر **كل استدعاء لنموذج اللغة** يتم خلال التشغيل — بما في ذلك الاستدعاءات من كل فريق (Crew) ينظمه التدفق، والاستدعاءات داخل أدوات الـ Agents، والاستدعاءات المباشرة لـ `LLM.call(...)` من دوال التدفق. هذا هو المكافئ على جانب الـ SDK للإجماليات المعروضة في واجهة CrewAI Enterprise.
|
||||
|
||||
```python Code
|
||||
from crewai import LLM
|
||||
from crewai.flow.flow import Flow, listen, start
|
||||
|
||||
class UsageMetricsFlow(Flow):
|
||||
@start()
|
||||
def run_first_crew(self):
|
||||
self.state.first_result = FirstCrew().crew().kickoff()
|
||||
|
||||
@listen(run_first_crew)
|
||||
def call_llm_directly(self):
|
||||
# استدعاء مباشر لنموذج اللغة — يُحسب أيضًا ضمن flow.usage_metrics
|
||||
llm = LLM(model="openai/gpt-4o-mini")
|
||||
self.state.summary = llm.call("لخّص النقاط الرئيسية.")
|
||||
|
||||
@listen(call_llm_directly)
|
||||
def run_second_crew(self):
|
||||
self.state.second_result = SecondCrew().crew().kickoff()
|
||||
|
||||
flow = UsageMetricsFlow()
|
||||
flow.kickoff()
|
||||
|
||||
print(flow.usage_metrics)
|
||||
# UsageMetrics(total_tokens=8579, prompt_tokens=6210, completion_tokens=2369,
|
||||
# cached_prompt_tokens=0, reasoning_tokens=0,
|
||||
# cache_creation_tokens=0, successful_requests=5)
|
||||
```
|
||||
|
||||
<Note>
|
||||
`flow.usage_metrics` **ليست** نفس `flow.kickoff().token_usage`. هذه الأخيرة
|
||||
ترجع فقط `CrewOutput.token_usage` لـ **آخر** دالة `@listen` أعادت
|
||||
`CrewOutput`، مما يعني أنها تعكس فقط الفريق الأخير وتتجاهل الفرق السابقة
|
||||
وكذلك أي استدعاءات مباشرة لـ `LLM.call(...)`. استخدم `flow.usage_metrics`
|
||||
كلما احتجت إلى الإجمالي **الكامل** للتوكنات لتنفيذ التدفق.
|
||||
</Note>
|
||||
|
||||
كل حقل في [`UsageMetrics`](https://github.com/crewAIInc/crewAI/blob/main/lib/crewai/src/crewai/types/usage_metrics.py) المُعاد هو مجموع جميع استدعاءات نموذج اللغة التي حدثت خلال استدعاء واحد لـ `flow.kickoff()`. تتم إعادة تعيين العدادات عند الاستدعاء التالي لـ `kickoff()` (وفي كل تكرار من `kickoff_for_each`)، لذلك لن تتكرر العدّات عبر التشغيلات المتتالية. يمكن قراءة هذه الخاصية بأمان في أي وقت بعد اكتمال `kickoff()`؛ قراءتها أثناء التنفيذ تُرجع المجموع الجزئي المتراكم حتى تلك اللحظة.
|
||||
|
||||
## إدارة حالة التدفق
|
||||
|
||||
إدارة الحالة بفعالية أمر بالغ الأهمية لبناء سير عمل ذكاء اصطناعي موثوق وقابل للصيانة. توفر تدفقات CrewAI آليات قوية لإدارة الحالة غير المهيكلة والمهيكلة،
|
||||
|
||||
@@ -226,6 +226,49 @@ After the Flow has run, you can access the final state to see the updates made b
|
||||
By ensuring that the final method's output is returned and providing access to the state, CrewAI Flows make it easy to integrate the results of your AI workflows into larger applications or systems,
|
||||
while also maintaining and accessing the state throughout the Flow's execution.
|
||||
|
||||
## Flow Usage Metrics
|
||||
|
||||
After a Flow execution completes, you can access the `usage_metrics` property to view aggregated token usage across **every LLM call** made during the run — including calls from every Crew the Flow orchestrated, calls inside Agent tools, and bare `LLM.call(...)` invocations from Flow methods. This is the SDK-side equivalent of the totals shown in the CrewAI Enterprise UI.
|
||||
|
||||
```python Code
|
||||
from crewai import LLM
|
||||
from crewai.flow.flow import Flow, listen, start
|
||||
|
||||
class UsageMetricsFlow(Flow):
|
||||
@start()
|
||||
def run_first_crew(self):
|
||||
self.state.first_result = FirstCrew().crew().kickoff()
|
||||
|
||||
@listen(run_first_crew)
|
||||
def call_llm_directly(self):
|
||||
# Bare LLM call — still counted by flow.usage_metrics
|
||||
llm = LLM(model="openai/gpt-4o-mini")
|
||||
self.state.summary = llm.call("Summarize the key takeaways.")
|
||||
|
||||
@listen(call_llm_directly)
|
||||
def run_second_crew(self):
|
||||
self.state.second_result = SecondCrew().crew().kickoff()
|
||||
|
||||
flow = UsageMetricsFlow()
|
||||
flow.kickoff()
|
||||
|
||||
print(flow.usage_metrics)
|
||||
# UsageMetrics(total_tokens=8579, prompt_tokens=6210, completion_tokens=2369,
|
||||
# cached_prompt_tokens=0, reasoning_tokens=0,
|
||||
# cache_creation_tokens=0, successful_requests=5)
|
||||
```
|
||||
|
||||
<Note>
|
||||
`flow.usage_metrics` is **not** the same as `flow.kickoff().token_usage`. The
|
||||
latter returns the `CrewOutput.token_usage` of the **last** `@listen` method
|
||||
that returned a `CrewOutput`, which means it only reflects the final Crew and
|
||||
ignores prior Crews and bare `LLM.call(...)` invocations entirely. Use
|
||||
`flow.usage_metrics` whenever you need the **full** token rollup for the Flow
|
||||
execution.
|
||||
</Note>
|
||||
|
||||
Each entry in the returned [`UsageMetrics`](https://github.com/crewAIInc/crewAI/blob/main/lib/crewai/src/crewai/types/usage_metrics.py) is the sum across all LLM calls made within a single `flow.kickoff()` invocation. Counters reset on the next `kickoff()` call (or on each iteration of `kickoff_for_each`), so successive runs don't double-count. The property is safe to read at any point after `kickoff()` completes; reading it during execution returns the partial total accumulated so far.
|
||||
|
||||
## Flow State Management
|
||||
|
||||
Managing state effectively is crucial for building reliable and maintainable AI workflows. CrewAI Flows provides robust mechanisms for both unstructured and structured state management,
|
||||
|
||||
@@ -221,6 +221,48 @@ Flow가 실행된 후, 이러한 메소드들에 의해 수행된 업데이트
|
||||
최종 메소드의 출력이 반환되고 상태에 접근할 수 있도록 함으로써, CrewAI Flow는 AI 워크플로우의 결과를 더 큰 애플리케이션이나 시스템에 쉽게 통합할 수 있게 하며,
|
||||
Flow 실행 과정 전반에 걸쳐 상태를 유지하고 접근하면서도 이를 용이하게 만듭니다.
|
||||
|
||||
## 플로우 사용 메트릭
|
||||
|
||||
Flow 실행이 완료된 후, `usage_metrics` 속성에 접근하여 실행 동안 발생한 **모든 LLM 호출**의 토큰 사용량 집계를 확인할 수 있습니다. 여기에는 Flow가 오케스트레이션한 모든 Crew의 호출, Agent의 도구 내부에서 발생한 호출, 그리고 Flow 메서드에서 직접 호출한 `LLM.call(...)`이 모두 포함됩니다. 이는 CrewAI Enterprise UI에 표시되는 총량과 동등한 SDK 측 값입니다.
|
||||
|
||||
```python Code
|
||||
from crewai import LLM
|
||||
from crewai.flow.flow import Flow, listen, start
|
||||
|
||||
class UsageMetricsFlow(Flow):
|
||||
@start()
|
||||
def run_first_crew(self):
|
||||
self.state.first_result = FirstCrew().crew().kickoff()
|
||||
|
||||
@listen(run_first_crew)
|
||||
def call_llm_directly(self):
|
||||
# 직접 LLM 호출 — flow.usage_metrics에서도 집계됩니다
|
||||
llm = LLM(model="openai/gpt-4o-mini")
|
||||
self.state.summary = llm.call("핵심 내용을 요약해 주세요.")
|
||||
|
||||
@listen(call_llm_directly)
|
||||
def run_second_crew(self):
|
||||
self.state.second_result = SecondCrew().crew().kickoff()
|
||||
|
||||
flow = UsageMetricsFlow()
|
||||
flow.kickoff()
|
||||
|
||||
print(flow.usage_metrics)
|
||||
# UsageMetrics(total_tokens=8579, prompt_tokens=6210, completion_tokens=2369,
|
||||
# cached_prompt_tokens=0, reasoning_tokens=0,
|
||||
# cache_creation_tokens=0, successful_requests=5)
|
||||
```
|
||||
|
||||
<Note>
|
||||
`flow.usage_metrics`는 `flow.kickoff().token_usage`와 **동일하지 않습니다**.
|
||||
후자는 `CrewOutput`을 반환한 **마지막** `@listen` 메서드의
|
||||
`CrewOutput.token_usage`만 반환하므로, 이전에 실행된 Crew들과 Flow 메서드에서
|
||||
직접 호출한 `LLM.call(...)`은 전혀 포함되지 않습니다. Flow 실행에 대한
|
||||
**전체** 토큰 집계가 필요할 때는 항상 `flow.usage_metrics`를 사용하십시오.
|
||||
</Note>
|
||||
|
||||
반환되는 [`UsageMetrics`](https://github.com/crewAIInc/crewAI/blob/main/lib/crewai/src/crewai/types/usage_metrics.py)의 각 항목은 단일 `flow.kickoff()` 실행 동안 발생한 모든 LLM 호출의 합계입니다. 다음 `kickoff()` 호출(및 `kickoff_for_each`의 각 반복)에서 카운터가 초기화되므로 연속 실행이 이중으로 집계되지 않습니다. 이 속성은 `kickoff()` 완료 후 언제든지 안전하게 읽을 수 있으며, 실행 중에 읽으면 그 시점까지 누적된 부분 합계를 반환합니다.
|
||||
|
||||
## 플로우 상태 관리
|
||||
|
||||
상태를 효과적으로 관리하는 것은 신뢰할 수 있고 유지 보수가 용이한 AI 워크플로를 구축하는 데 매우 중요합니다. CrewAI 플로우는 비정형 및 정형 상태 관리를 위한 강력한 메커니즘을 제공하여, 개발자가 자신의 애플리케이션에 가장 적합한 접근 방식을 선택할 수 있도록 합니다.
|
||||
|
||||
@@ -219,6 +219,49 @@ Após o término da execução, é possível acessar o estado final e observar a
|
||||
Ao garantir que a saída do método final seja retornada e oferecer acesso ao estado, o CrewAI Flows facilita a integração dos resultados dos seus workflows de IA em aplicações maiores,
|
||||
além de permitir o gerenciamento e o acesso ao estado durante toda a execução do Flow.
|
||||
|
||||
## Métricas de Uso do Flow
|
||||
|
||||
Após a execução de um Flow, você pode acessar a propriedade `usage_metrics` para visualizar o consumo agregado de tokens em **todas as chamadas de LLM** realizadas durante a execução — incluindo chamadas das Crews orquestradas pelo Flow, chamadas dentro de tools de Agents, e invocações diretas de `LLM.call(...)` feitas a partir de métodos do Flow. Esse é o equivalente, do lado do SDK, ao total exibido na interface do CrewAI Enterprise.
|
||||
|
||||
```python Code
|
||||
from crewai import LLM
|
||||
from crewai.flow.flow import Flow, listen, start
|
||||
|
||||
class UsageMetricsFlow(Flow):
|
||||
@start()
|
||||
def run_first_crew(self):
|
||||
self.state.first_result = FirstCrew().crew().kickoff()
|
||||
|
||||
@listen(run_first_crew)
|
||||
def call_llm_directly(self):
|
||||
# Chamada direta de LLM — também contabilizada por flow.usage_metrics
|
||||
llm = LLM(model="openai/gpt-4o-mini")
|
||||
self.state.summary = llm.call("Resuma os principais pontos.")
|
||||
|
||||
@listen(call_llm_directly)
|
||||
def run_second_crew(self):
|
||||
self.state.second_result = SecondCrew().crew().kickoff()
|
||||
|
||||
flow = UsageMetricsFlow()
|
||||
flow.kickoff()
|
||||
|
||||
print(flow.usage_metrics)
|
||||
# UsageMetrics(total_tokens=8579, prompt_tokens=6210, completion_tokens=2369,
|
||||
# cached_prompt_tokens=0, reasoning_tokens=0,
|
||||
# cache_creation_tokens=0, successful_requests=5)
|
||||
```
|
||||
|
||||
<Note>
|
||||
`flow.usage_metrics` **não** é o mesmo que `flow.kickoff().token_usage`. Este
|
||||
último retorna apenas o `CrewOutput.token_usage` do **último** método
|
||||
`@listen` que retornou um `CrewOutput`, ou seja, reflete somente a Crew
|
||||
final e ignora completamente as Crews anteriores e quaisquer chamadas
|
||||
diretas de `LLM.call(...)`. Use `flow.usage_metrics` sempre que precisar do
|
||||
rollup **completo** de tokens da execução do Flow.
|
||||
</Note>
|
||||
|
||||
Cada campo do [`UsageMetrics`](https://github.com/crewAIInc/crewAI/blob/main/lib/crewai/src/crewai/types/usage_metrics.py) retornado representa a soma de todas as chamadas de LLM feitas em uma única invocação de `flow.kickoff()`. Os contadores são resetados a cada novo `kickoff()` (e em cada iteração de `kickoff_for_each`), de modo que execuções sucessivas não duplicam o total. A propriedade é segura para ser lida em qualquer momento após o `kickoff()`; lê-la durante a execução retorna o total parcial acumulado até aquele instante.
|
||||
|
||||
## Gerenciamento de Estado em Flows
|
||||
|
||||
Gerenciar o estado de forma eficaz é fundamental para construir fluxos de trabalho de IA confiáveis e de fácil manutenção. O CrewAI Flows oferece mecanismos robustos para o gerenciamento de estado tanto não estruturado quanto estruturado,
|
||||
|
||||
@@ -26,6 +26,7 @@ from crewai_cli.remote_template.main import TemplateCommand
|
||||
from crewai_cli.replay_from_task import replay_task_command
|
||||
from crewai_cli.reset_memories_command import reset_memories_command
|
||||
from crewai_cli.run_crew import run_crew
|
||||
from crewai_cli.run_flow_definition import run_flow_definition
|
||||
from crewai_cli.settings.main import SettingsCommand
|
||||
from crewai_cli.task_outputs import load_task_outputs
|
||||
from crewai_cli.tools.main import ToolCommand
|
||||
@@ -398,8 +399,36 @@ def install(context: click.Context) -> None:
|
||||
"CREWAI_TRAINED_AGENTS_FILE."
|
||||
),
|
||||
)
|
||||
def run(trained_agents_file: str | None) -> None:
|
||||
"""Run the Crew."""
|
||||
@click.option(
|
||||
"--definition",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Experimental: path to a Flow Definition YAML/JSON file, "
|
||||
"or an inline YAML/JSON string."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--inputs",
|
||||
type=str,
|
||||
default=None,
|
||||
help='Experimental: JSON object passed to flow.kickoff(), e.g. \'{"topic":"AI"}\'.',
|
||||
)
|
||||
def run(
|
||||
trained_agents_file: str | None, definition: str | None, inputs: str | None
|
||||
) -> None:
|
||||
"""Run the Crew or Flow."""
|
||||
if inputs is not None and definition is None:
|
||||
raise click.UsageError("--inputs requires --definition")
|
||||
|
||||
if definition is not None:
|
||||
click.secho(
|
||||
"Warning: `crewai run --definition` is experimental and may change without notice.",
|
||||
fg="yellow",
|
||||
)
|
||||
run_flow_definition(definition=definition, inputs=inputs)
|
||||
return
|
||||
|
||||
run_crew(trained_agents_file=trained_agents_file)
|
||||
|
||||
|
||||
|
||||
113
lib/cli/src/crewai_cli/run_flow_definition.py
Normal file
113
lib/cli/src/crewai_cli/run_flow_definition.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def run_flow_definition(definition: str, inputs: str | None = None) -> None:
|
||||
"""Run a flow from a Flow Definition YAML/JSON string or file path."""
|
||||
try:
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.flow.flow_definition import FlowDefinition
|
||||
except ImportError as exc:
|
||||
click.echo(
|
||||
"Running flows from definitions requires the full crewai package.",
|
||||
err=True,
|
||||
)
|
||||
raise SystemExit(1) from exc
|
||||
|
||||
parsed_inputs = _parse_inputs(inputs)
|
||||
definition_source = _read_definition_source(definition)
|
||||
|
||||
try:
|
||||
flow_definition = _parse_flow_definition(FlowDefinition, definition_source)
|
||||
flow = Flow.from_definition(flow_definition)
|
||||
result = flow.kickoff(inputs=parsed_inputs)
|
||||
except Exception as exc:
|
||||
click.echo(
|
||||
f"An error occurred while running the flow definition: {exc}", err=True
|
||||
)
|
||||
raise SystemExit(1) from exc
|
||||
|
||||
click.echo(_format_result(result))
|
||||
|
||||
|
||||
def _parse_inputs(inputs: str | None) -> dict[str, Any] | None:
|
||||
if inputs is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = json.loads(inputs)
|
||||
except json.JSONDecodeError as exc:
|
||||
click.echo(f"Invalid --inputs JSON: {exc}", err=True)
|
||||
raise SystemExit(1) from exc
|
||||
|
||||
if not isinstance(parsed, dict):
|
||||
click.echo("Invalid --inputs JSON: expected an object.", err=True)
|
||||
raise SystemExit(1)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def _read_definition_source(definition: str) -> str:
|
||||
path = Path(definition).expanduser()
|
||||
try:
|
||||
is_file = path.is_file()
|
||||
except OSError as exc:
|
||||
if _looks_like_inline_definition(definition):
|
||||
return definition
|
||||
click.echo(f"Invalid --definition path: {definition} ({exc})", err=True)
|
||||
raise SystemExit(1) from exc
|
||||
|
||||
if is_file:
|
||||
try:
|
||||
return path.read_text(encoding="utf-8")
|
||||
except (OSError, UnicodeError) as exc:
|
||||
click.echo(
|
||||
f"Unable to read --definition path {path}: {exc}",
|
||||
err=True,
|
||||
)
|
||||
raise SystemExit(1) from exc
|
||||
|
||||
try:
|
||||
if path.exists():
|
||||
click.echo(
|
||||
f"Invalid --definition path: {definition} is not a file.", err=True
|
||||
)
|
||||
raise SystemExit(1)
|
||||
except OSError as exc:
|
||||
click.echo(f"Invalid --definition path: {definition} ({exc})", err=True)
|
||||
raise SystemExit(1) from exc
|
||||
|
||||
return definition
|
||||
|
||||
|
||||
def _looks_like_inline_definition(definition: str) -> bool:
|
||||
stripped = definition.lstrip()
|
||||
return "\n" in definition or stripped.startswith(("{", "---")) or ":" in stripped
|
||||
|
||||
|
||||
def _parse_flow_definition(flow_definition_cls: type[Any], source: str) -> Any:
|
||||
if _looks_like_json(source):
|
||||
return flow_definition_cls.from_json(source)
|
||||
|
||||
return flow_definition_cls.from_yaml(source)
|
||||
|
||||
|
||||
def _looks_like_json(source: str) -> bool:
|
||||
stripped = source.lstrip()
|
||||
return stripped.startswith("{")
|
||||
|
||||
|
||||
def _format_result(result: Any) -> str:
|
||||
raw_result = getattr(result, "raw", result)
|
||||
if isinstance(raw_result, str):
|
||||
return raw_result
|
||||
|
||||
try:
|
||||
return json.dumps(raw_result, default=str)
|
||||
except TypeError:
|
||||
return str(raw_result)
|
||||
@@ -13,6 +13,7 @@ from crewai_cli.cli import (
|
||||
flow_add_crew,
|
||||
login,
|
||||
reset_memories,
|
||||
run,
|
||||
test,
|
||||
train,
|
||||
version,
|
||||
@@ -119,6 +120,43 @@ def test_test_invalid_string_iterations(evaluate_crew, runner):
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.run_crew")
|
||||
def test_run_uses_project_runner_by_default(run_crew, runner):
|
||||
result = runner.invoke(run)
|
||||
|
||||
assert result.exit_code == 0
|
||||
run_crew.assert_called_once_with(trained_agents_file=None)
|
||||
assert "experimental" not in result.output.lower()
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.run_flow_definition")
|
||||
def test_run_with_definition_uses_definition_runner(run_flow_definition, runner):
|
||||
result = runner.invoke(
|
||||
run,
|
||||
["--definition", "flow.yaml", "--inputs", '{"topic":"AI"}'],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert (
|
||||
"Warning: `crewai run --definition` is experimental and may change without notice."
|
||||
in result.output
|
||||
)
|
||||
run_flow_definition.assert_called_once_with(
|
||||
definition="flow.yaml", inputs='{"topic":"AI"}'
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.run_crew")
|
||||
@mock.patch("crewai_cli.cli.run_flow_definition")
|
||||
def test_run_rejects_inputs_without_definition(run_flow_definition, run_crew, runner):
|
||||
result = runner.invoke(run, ["--inputs", '{"topic":"AI"}'])
|
||||
|
||||
assert result.exit_code == 2
|
||||
assert "Error: --inputs requires --definition" in result.output
|
||||
run_flow_definition.assert_not_called()
|
||||
run_crew.assert_not_called()
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.AuthenticationCommand")
|
||||
def test_login(command, runner):
|
||||
mock_auth = command.return_value
|
||||
|
||||
156
lib/cli/tests/test_run_flow_definition.py
Normal file
156
lib/cli/tests/test_run_flow_definition.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from crewai_cli.run_flow_definition import run_flow_definition
|
||||
|
||||
|
||||
class _FakeFlow:
|
||||
def __init__(self, definition):
|
||||
self.definition = definition
|
||||
|
||||
def kickoff(self, inputs=None):
|
||||
return {
|
||||
"flow": self.definition["name"],
|
||||
"inputs": inputs or {},
|
||||
}
|
||||
|
||||
|
||||
class _FakeFlowFactory:
|
||||
@classmethod
|
||||
def from_definition(cls, definition):
|
||||
return _FakeFlow(definition)
|
||||
|
||||
|
||||
class _FakeFlowDefinition:
|
||||
@classmethod
|
||||
def from_yaml(cls, source):
|
||||
return yaml.safe_load(source)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, source):
|
||||
return json.loads(source)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_flow_runtime(monkeypatch):
|
||||
crewai_module = types.ModuleType("crewai")
|
||||
flow_package = types.ModuleType("crewai.flow")
|
||||
flow_module = types.ModuleType("crewai.flow.flow")
|
||||
flow_definition_module = types.ModuleType("crewai.flow.flow_definition")
|
||||
|
||||
flow_module.Flow = _FakeFlowFactory
|
||||
flow_definition_module.FlowDefinition = _FakeFlowDefinition
|
||||
|
||||
monkeypatch.setitem(sys.modules, "crewai", crewai_module)
|
||||
monkeypatch.setitem(sys.modules, "crewai.flow", flow_package)
|
||||
monkeypatch.setitem(sys.modules, "crewai.flow.flow", flow_module)
|
||||
monkeypatch.setitem(
|
||||
sys.modules, "crewai.flow.flow_definition", flow_definition_module
|
||||
)
|
||||
|
||||
|
||||
def _captured_json(capsys):
|
||||
return json.loads(capsys.readouterr().out)
|
||||
|
||||
|
||||
def test_run_flow_definition_reads_definition_file(
|
||||
tmp_path, capsys, fake_flow_runtime
|
||||
):
|
||||
definition_path = tmp_path / "flow.yaml"
|
||||
definition_path.write_text("schema: crewai.flow/v1\nname: TestFlow\n")
|
||||
|
||||
run_flow_definition(str(definition_path), '{"topic":"AI"}')
|
||||
|
||||
assert _captured_json(capsys) == {
|
||||
"flow": "TestFlow",
|
||||
"inputs": {"topic": "AI"},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("definition_source", "expected_flow_name"),
|
||||
[
|
||||
pytest.param(
|
||||
"schema: crewai.flow/v1\nname: InlineFlow\n",
|
||||
"InlineFlow",
|
||||
id="inline-yaml",
|
||||
),
|
||||
pytest.param(
|
||||
'{"schema":"crewai.flow/v1","name":"InlineJsonFlow"}',
|
||||
"InlineJsonFlow",
|
||||
id="inline-json",
|
||||
),
|
||||
pytest.param(
|
||||
'{"schema":"crewai.flow/v1","name":"' + ("JsonFlow" * 500) + '"}',
|
||||
"JsonFlow" * 500,
|
||||
id="large-inline-json",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_run_flow_definition_accepts_inline_definitions(
|
||||
definition_source, expected_flow_name, capsys, fake_flow_runtime
|
||||
):
|
||||
run_flow_definition(definition_source)
|
||||
|
||||
assert _captured_json(capsys) == {"flow": expected_flow_name, "inputs": {}}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("filename", "definition_source", "expected_flow_name"),
|
||||
[
|
||||
pytest.param(
|
||||
"flow.yaml",
|
||||
"schema: crewai.flow/v1\nname: YamlFileFlow\n",
|
||||
"YamlFileFlow",
|
||||
id="yaml-file",
|
||||
),
|
||||
pytest.param(
|
||||
"flow.json",
|
||||
'{"schema":"crewai.flow/v1","name":"JsonFlow"}',
|
||||
"JsonFlow",
|
||||
id="json-file",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_run_flow_definition_accepts_definition_files(
|
||||
filename, definition_source, expected_flow_name, tmp_path, capsys, fake_flow_runtime
|
||||
):
|
||||
definition_path = tmp_path / filename
|
||||
definition_path.write_text(definition_source)
|
||||
|
||||
run_flow_definition(str(definition_path))
|
||||
|
||||
assert _captured_json(capsys) == {"flow": expected_flow_name, "inputs": {}}
|
||||
|
||||
|
||||
def test_run_flow_definition_rejects_non_object_inputs(fake_flow_runtime, capsys):
|
||||
with pytest.raises(SystemExit):
|
||||
run_flow_definition("name: TestFlow", '["not", "an", "object"]')
|
||||
|
||||
assert "Invalid --inputs JSON: expected an object." in capsys.readouterr().err
|
||||
|
||||
|
||||
def test_run_flow_definition_reports_unreadable_file(
|
||||
monkeypatch, tmp_path, capsys, fake_flow_runtime
|
||||
):
|
||||
definition_path = tmp_path / "flow.yaml"
|
||||
definition_path.write_text("schema: crewai.flow/v1\nname: TestFlow\n")
|
||||
|
||||
def raise_permission_error(self, *args, **kwargs):
|
||||
raise PermissionError("no access")
|
||||
|
||||
monkeypatch.setattr("pathlib.Path.read_text", raise_permission_error)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
run_flow_definition(str(definition_path))
|
||||
|
||||
err = capsys.readouterr().err
|
||||
assert "Unable to read --definition path" in err
|
||||
assert str(definition_path) in err
|
||||
assert "no access" in err
|
||||
@@ -13,8 +13,8 @@ from crewai_core import (
|
||||
user_data,
|
||||
version,
|
||||
)
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
import pytest
|
||||
|
||||
|
||||
def test_version_returns_string() -> None:
|
||||
|
||||
@@ -63,7 +63,7 @@ spider-client = [
|
||||
"spider-client>=0.1.25",
|
||||
]
|
||||
scrapegraph-py = [
|
||||
"scrapegraph-py>=1.9.0",
|
||||
"scrapegraph-py>=1.9.0,<2",
|
||||
]
|
||||
linkup-sdk = [
|
||||
"linkup-sdk>=0.2.2",
|
||||
|
||||
@@ -22,6 +22,31 @@ logger = logging.getLogger(__name__)
|
||||
_UNSAFE_PATHS_ENV = "CREWAI_TOOLS_ALLOW_UNSAFE_PATHS"
|
||||
|
||||
|
||||
def format_path_for_display(path: str, base_dir: str | None = None) -> str:
|
||||
"""Return a path label that does not expose absolute directory prefixes."""
|
||||
if base_dir is None:
|
||||
base_dir = os.getcwd()
|
||||
|
||||
try:
|
||||
resolved_base = os.path.realpath(base_dir)
|
||||
resolved_path = os.path.realpath(
|
||||
os.path.join(resolved_base, path) if not os.path.isabs(path) else path
|
||||
)
|
||||
if os.path.commonpath([resolved_base, resolved_path]) == resolved_base:
|
||||
return os.path.relpath(resolved_path, resolved_base)
|
||||
except (OSError, ValueError) as exc:
|
||||
logger.debug("Falling back to basename for display path formatting: %s", exc)
|
||||
|
||||
return os.path.basename(os.path.realpath(path)) or "[redacted path]"
|
||||
|
||||
|
||||
def format_error_for_display(error: Exception) -> str:
|
||||
"""Return exception details without OS-added absolute path context."""
|
||||
if isinstance(error, OSError):
|
||||
return error.strerror or error.__class__.__name__
|
||||
return str(error)
|
||||
|
||||
|
||||
def _is_escape_hatch_enabled() -> bool:
|
||||
"""Check if the unsafe paths escape hatch is enabled."""
|
||||
return os.environ.get(_UNSAFE_PATHS_ENV, "").lower() in ("true", "1", "yes")
|
||||
@@ -66,8 +91,8 @@ def validate_file_path(path: str, base_dir: str | None = None) -> str:
|
||||
prefix = resolved_base if resolved_base.endswith(os.sep) else resolved_base + os.sep
|
||||
if not resolved_path.startswith(prefix) and resolved_path != resolved_base:
|
||||
raise ValueError(
|
||||
f"Path '{path}' resolves to '{resolved_path}' which is outside "
|
||||
f"the allowed directory '{resolved_base}'. "
|
||||
f"Path '{format_path_for_display(resolved_path, resolved_base)}' is "
|
||||
f"outside the allowed directory. "
|
||||
f"Set {_UNSAFE_PATHS_ENV}=true to bypass this check."
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,11 @@ from typing import Any
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.security.safe_path import validate_file_path
|
||||
from crewai_tools.security.safe_path import (
|
||||
format_error_for_display,
|
||||
format_path_for_display,
|
||||
validate_file_path,
|
||||
)
|
||||
|
||||
|
||||
class FileReadToolSchema(BaseModel):
|
||||
@@ -58,8 +62,9 @@ class FileReadTool(BaseTool):
|
||||
**kwargs: Additional keyword arguments passed to BaseTool.
|
||||
"""
|
||||
if file_path is not None:
|
||||
display_path = format_path_for_display(file_path)
|
||||
kwargs["description"] = (
|
||||
f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file. You can also specify 'start_line' and 'line_count' to read specific parts of the file."
|
||||
f"A tool that reads file content. The default file is {display_path}, but you can provide a different 'file_path' parameter to read another file. You can also specify 'start_line' and 'line_count' to read specific parts of the file."
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
@@ -78,7 +83,12 @@ class FileReadTool(BaseTool):
|
||||
if file_path is None:
|
||||
return "Error: No file path provided. Please provide a file path either in the constructor or as an argument."
|
||||
|
||||
file_path = validate_file_path(file_path)
|
||||
try:
|
||||
file_path = validate_file_path(file_path)
|
||||
except ValueError as e:
|
||||
return f"Error: Invalid file path: {e!s}"
|
||||
|
||||
display_path = format_path_for_display(file_path)
|
||||
try:
|
||||
with open(file_path, "r") as file:
|
||||
if start_line == 1 and line_count is None:
|
||||
@@ -98,8 +108,11 @@ class FileReadTool(BaseTool):
|
||||
|
||||
return "".join(selected_lines)
|
||||
except FileNotFoundError:
|
||||
return f"Error: File not found at path: {file_path}"
|
||||
return f"Error: File not found at path: {display_path}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied when trying to read file: {file_path}"
|
||||
return f"Error: Permission denied when trying to read file: {display_path}"
|
||||
except Exception as e:
|
||||
return f"Error: Failed to read file {file_path}. {e!s}"
|
||||
return (
|
||||
f"Error: Failed to read file {display_path}. "
|
||||
f"{format_error_for_display(e)}"
|
||||
)
|
||||
|
||||
@@ -5,6 +5,11 @@ from typing import Any
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.security.safe_path import (
|
||||
format_error_for_display,
|
||||
format_path_for_display,
|
||||
)
|
||||
|
||||
|
||||
def strtobool(val: str | bool) -> bool:
|
||||
if isinstance(val, bool):
|
||||
@@ -44,6 +49,9 @@ class FileWriterTool(BaseTool):
|
||||
# itself, since that is not a valid file target.
|
||||
real_directory = Path(directory).resolve()
|
||||
real_filepath = Path(filepath).resolve()
|
||||
display_filepath = format_path_for_display(
|
||||
str(real_filepath), str(real_directory)
|
||||
)
|
||||
if (
|
||||
not real_filepath.is_relative_to(real_directory)
|
||||
or real_filepath == real_directory
|
||||
@@ -56,15 +64,18 @@ class FileWriterTool(BaseTool):
|
||||
kwargs["overwrite"] = strtobool(kwargs["overwrite"])
|
||||
|
||||
if os.path.exists(real_filepath) and not kwargs["overwrite"]:
|
||||
return f"File {real_filepath} already exists and overwrite option was not passed."
|
||||
return f"File {display_filepath} already exists and overwrite option was not passed."
|
||||
|
||||
mode = "w" if kwargs["overwrite"] else "x"
|
||||
with open(real_filepath, mode) as file:
|
||||
file.write(kwargs["content"])
|
||||
return f"Content successfully written to {real_filepath}"
|
||||
return f"Content successfully written to {display_filepath}"
|
||||
except FileExistsError:
|
||||
return f"File {real_filepath} already exists and overwrite option was not passed."
|
||||
return f"File {display_filepath} already exists and overwrite option was not passed."
|
||||
except KeyError as e:
|
||||
return f"An error occurred while accessing key: {e!s}"
|
||||
except Exception as e:
|
||||
return f"An error occurred while writing to the file: {e!s}"
|
||||
return (
|
||||
"An error occurred while writing to the file: "
|
||||
f"{format_error_for_display(e)}"
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
from crewai_tools import FileReadTool
|
||||
@@ -6,21 +5,16 @@ from crewai_tools import FileReadTool
|
||||
|
||||
def test_file_read_tool_constructor():
|
||||
"""Test FileReadTool initialization with file_path."""
|
||||
test_file = "/tmp/test_file.txt"
|
||||
test_content = "Hello, World!"
|
||||
with open(test_file, "w") as f:
|
||||
f.write(test_content)
|
||||
test_file = "test_file.txt"
|
||||
|
||||
tool = FileReadTool(file_path=test_file)
|
||||
assert tool.file_path == test_file
|
||||
assert "test_file.txt" in tool.description
|
||||
|
||||
os.remove(test_file)
|
||||
|
||||
|
||||
def test_file_read_tool_run():
|
||||
"""Test FileReadTool _run method with file_path at runtime."""
|
||||
test_file = "/tmp/test_file.txt"
|
||||
test_file = "test_file.txt"
|
||||
test_content = "Hello, World!"
|
||||
|
||||
# Use mock_open to mock file operations
|
||||
@@ -36,18 +30,18 @@ def test_file_read_tool_error_handling():
|
||||
result = tool._run()
|
||||
assert "Error: No file path provided" in result
|
||||
|
||||
result = tool._run(file_path="/nonexistent/file.txt")
|
||||
result = tool._run(file_path="nonexistent/file.txt")
|
||||
assert "Error: File not found at path:" in result
|
||||
|
||||
with patch("builtins.open", side_effect=PermissionError()):
|
||||
result = tool._run(file_path="/tmp/no_permission.txt")
|
||||
result = tool._run(file_path="no_permission.txt")
|
||||
assert "Error: Permission denied" in result
|
||||
|
||||
|
||||
def test_file_read_tool_constructor_and_run():
|
||||
"""Test FileReadTool using both constructor and runtime file paths."""
|
||||
test_file1 = "/tmp/test1.txt"
|
||||
test_file2 = "/tmp/test2.txt"
|
||||
test_file1 = "test1.txt"
|
||||
test_file2 = "test2.txt"
|
||||
content1 = "File 1 content"
|
||||
content2 = "File 2 content"
|
||||
|
||||
@@ -64,7 +58,7 @@ def test_file_read_tool_constructor_and_run():
|
||||
|
||||
def test_file_read_tool_chunk_reading():
|
||||
"""Test FileReadTool reading specific chunks of a file."""
|
||||
test_file = "/tmp/multiline_test.txt"
|
||||
test_file = "multiline_test.txt"
|
||||
lines = [
|
||||
"Line 1\n",
|
||||
"Line 2\n",
|
||||
@@ -104,7 +98,7 @@ def test_file_read_tool_chunk_reading():
|
||||
|
||||
def test_file_read_tool_chunk_error_handling():
|
||||
"""Test error handling for chunk reading."""
|
||||
test_file = "/tmp/short_test.txt"
|
||||
test_file = "short_test.txt"
|
||||
lines = ["Line 1\n", "Line 2\n", "Line 3\n"]
|
||||
file_content = "".join(lines)
|
||||
|
||||
@@ -122,7 +116,7 @@ def test_file_read_tool_chunk_error_handling():
|
||||
|
||||
def test_file_read_tool_zero_or_negative_start_line():
|
||||
"""Test that start_line values of 0 or negative read from the start of the file."""
|
||||
test_file = "/tmp/negative_test.txt"
|
||||
test_file = "negative_test.txt"
|
||||
lines = ["Line 1\n", "Line 2\n", "Line 3\n", "Line 4\n", "Line 5\n"]
|
||||
file_content = "".join(lines)
|
||||
|
||||
@@ -150,3 +144,45 @@ def test_file_read_tool_zero_or_negative_start_line():
|
||||
result = tool._run(file_path=test_file, start_line=-10, line_count=2)
|
||||
expected = "".join(lines[0:2]) # Should read first 2 lines
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_file_read_tool_error_messages_do_not_disclose_absolute_paths(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
"""FileReadTool should redact absolute prefixes from user-visible errors."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
tool = FileReadTool()
|
||||
target = tmp_path / "secret.txt"
|
||||
|
||||
result = tool._run(file_path=str(target))
|
||||
assert "secret.txt" in result
|
||||
assert str(tmp_path) not in result
|
||||
|
||||
target.touch()
|
||||
with patch("builtins.open", side_effect=PermissionError()):
|
||||
result = tool._run(file_path=str(target))
|
||||
assert "secret.txt" in result
|
||||
assert str(tmp_path) not in result
|
||||
|
||||
with patch(
|
||||
"builtins.open",
|
||||
side_effect=OSError(5, "Input/output error", str(target)),
|
||||
):
|
||||
result = tool._run(file_path=str(target))
|
||||
assert "secret.txt" in result
|
||||
assert str(tmp_path) not in result
|
||||
|
||||
|
||||
def test_file_read_tool_invalid_path_error_does_not_disclose_workspace(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
"""Validation errors should not echo the resolved workspace path."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
outside = tmp_path.parent / "outside.txt"
|
||||
|
||||
result = FileReadTool()._run(file_path=str(outside))
|
||||
|
||||
assert "Invalid file path" in result
|
||||
assert "outside.txt" in result
|
||||
assert str(tmp_path) not in result
|
||||
assert str(tmp_path.parent) not in result
|
||||
|
||||
@@ -47,6 +47,8 @@ def test_basic_file_write(tool, temp_env):
|
||||
assert os.path.exists(path)
|
||||
assert read_file(path) == temp_env["test_content"]
|
||||
assert "successfully written" in result
|
||||
assert temp_env["test_file"] in result
|
||||
assert temp_env["temp_dir"] not in result
|
||||
|
||||
|
||||
def test_directory_creation(tool, temp_env):
|
||||
@@ -62,6 +64,8 @@ def test_directory_creation(tool, temp_env):
|
||||
assert os.path.exists(new_dir)
|
||||
assert os.path.exists(path)
|
||||
assert "successfully written" in result
|
||||
assert temp_env["test_file"] in result
|
||||
assert new_dir not in result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -134,6 +138,8 @@ def test_file_exists_error_handling(tool, temp_env, overwrite):
|
||||
)
|
||||
|
||||
assert "already exists and overwrite option was not passed" in result
|
||||
assert temp_env["test_file"] in result
|
||||
assert temp_env["temp_dir"] not in result
|
||||
assert read_file(path) == "Pre-existing content"
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
import pytest
|
||||
|
||||
from crewai_tools.security.safe_path import (
|
||||
format_path_for_display,
|
||||
validate_directory_path,
|
||||
validate_file_path,
|
||||
validate_url,
|
||||
@@ -66,6 +67,37 @@ class TestValidateFilePath:
|
||||
result = validate_file_path("/etc/passwd", str(tmp_path))
|
||||
assert result == os.path.realpath("/etc/passwd")
|
||||
|
||||
def test_rejection_message_redacts_absolute_prefixes(self, tmp_path):
|
||||
outside = tmp_path.parent / "outside.txt"
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
validate_file_path(str(outside), str(tmp_path))
|
||||
|
||||
message = str(exc_info.value)
|
||||
assert "outside.txt" in message
|
||||
assert str(tmp_path) not in message
|
||||
assert str(tmp_path.parent) not in message
|
||||
|
||||
|
||||
class TestFormatPathForDisplay:
|
||||
"""Tests for user-visible path labels."""
|
||||
|
||||
def test_returns_relative_path_inside_base(self, tmp_path):
|
||||
nested_file = tmp_path / "nested" / "file.txt"
|
||||
nested_file.parent.mkdir()
|
||||
nested_file.touch()
|
||||
|
||||
result = format_path_for_display(str(nested_file), str(tmp_path))
|
||||
|
||||
assert result == os.path.join("nested", "file.txt")
|
||||
|
||||
def test_redacts_absolute_prefix_outside_base(self, tmp_path):
|
||||
outside_file = tmp_path.parent / "outside.txt"
|
||||
|
||||
result = format_path_for_display(str(outside_file), str(tmp_path))
|
||||
|
||||
assert result == "outside.txt"
|
||||
|
||||
|
||||
class TestValidateDirectoryPath:
|
||||
"""Tests for validate_directory_path."""
|
||||
|
||||
@@ -33,6 +33,7 @@ dependencies = [
|
||||
"appdirs~=1.4.4",
|
||||
"jsonref~=1.1.0",
|
||||
"json-repair~=0.25.2",
|
||||
"cel-python>=0.5.0,<0.6",
|
||||
"tomli-w~=1.1.0",
|
||||
"tomli~=2.0.2",
|
||||
"json5~=0.10.0",
|
||||
|
||||
@@ -158,7 +158,6 @@ class EventListener(BaseEventListener):
|
||||
trace_listener.formatter = self.formatter
|
||||
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
|
||||
@crewai_event_bus.on(CCEnvEvent)
|
||||
def on_cc_env(_: Any, event: CCEnvEvent) -> None:
|
||||
self._telemetry.env_context_span(event.type)
|
||||
|
||||
@@ -47,7 +47,7 @@ from crewai.flow.conversation import (
|
||||
receive_user_message as _receive_user_message,
|
||||
)
|
||||
from crewai.flow.dsl import listen, start
|
||||
from crewai.flow.dsl._utils import _set_flow_method_definition
|
||||
from crewai.flow.dsl._utils import _method_action, _set_flow_method_definition
|
||||
from crewai.flow.flow_definition import FlowMethodDefinition
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
@@ -78,7 +78,7 @@ def _conversation_start_router(func: Callable[..., Any]) -> Any:
|
||||
wrapper = start()(func)
|
||||
_set_flow_method_definition(
|
||||
cast(Any, wrapper),
|
||||
FlowMethodDefinition(start=True, router=True),
|
||||
FlowMethodDefinition(do=_method_action(func), start=True, router=True),
|
||||
)
|
||||
return wrapper
|
||||
|
||||
@@ -146,6 +146,10 @@ class _ConversationalMixin:
|
||||
def kickoff(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
@property
|
||||
def method_outputs(self) -> list[Any]:
|
||||
pass
|
||||
|
||||
def conversation_start(self) -> str | None:
|
||||
"""Return the current user message for conversational route selection.
|
||||
|
||||
@@ -1033,7 +1037,8 @@ class _ConversationalMixin:
|
||||
# of warning about an empty scope stack.
|
||||
started_id = getattr(self, "_deferred_flow_started_event_id", None)
|
||||
if started_id:
|
||||
last_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
method_outputs = self.method_outputs
|
||||
last_output = method_outputs[-1] if method_outputs else None
|
||||
restore_event_scope(((started_id, "flow_started"),))
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
|
||||
@@ -3,11 +3,10 @@ from __future__ import annotations
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
from crewai.flow.flow_definition import FlowMethodDefinition
|
||||
from crewai.flow.human_feedback import (
|
||||
HumanFeedbackConfig,
|
||||
HumanFeedbackResult,
|
||||
_build_human_feedback_runtime_decorator,
|
||||
_validate_human_feedback_options,
|
||||
)
|
||||
|
||||
|
||||
@@ -21,32 +20,6 @@ F = TypeVar("F", bound=Callable[..., Any])
|
||||
__all__ = ["HumanFeedbackResult", "human_feedback"]
|
||||
|
||||
|
||||
def _stamp_human_feedback_metadata(
|
||||
wrapper: Any,
|
||||
func: Callable[..., Any],
|
||||
config: HumanFeedbackConfig,
|
||||
) -> None:
|
||||
for attr in [
|
||||
"__is_flow_method__",
|
||||
"__flow_persistence_config__",
|
||||
"__flow_method_definition__",
|
||||
]:
|
||||
if hasattr(func, attr):
|
||||
setattr(wrapper, attr, getattr(func, attr))
|
||||
|
||||
wrapper.__human_feedback_config__ = config
|
||||
wrapper.__is_flow_method__ = True
|
||||
|
||||
if config.emit:
|
||||
fragment = getattr(wrapper, "__flow_method_definition__", None)
|
||||
if isinstance(fragment, FlowMethodDefinition):
|
||||
wrapper.__flow_method_definition__ = fragment.model_copy(
|
||||
update={"router": True, "emit": list(config.emit)}
|
||||
)
|
||||
|
||||
wrapper._human_feedback_llm = config.llm
|
||||
|
||||
|
||||
def human_feedback(
|
||||
message: str,
|
||||
emit: Sequence[str] | None = None,
|
||||
@@ -58,21 +31,18 @@ def human_feedback(
|
||||
learn_source: str = "hitl",
|
||||
learn_strict: bool = False,
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator for Flow methods that require human feedback."""
|
||||
runtime_decorator = _build_human_feedback_runtime_decorator(
|
||||
message=message,
|
||||
emit=emit,
|
||||
llm=llm,
|
||||
default_outcome=default_outcome,
|
||||
metadata=metadata,
|
||||
provider=provider,
|
||||
learn=learn,
|
||||
learn_source=learn_source,
|
||||
learn_strict=learn_strict,
|
||||
"""Decorator for Flow methods that require human feedback.
|
||||
|
||||
The decorator is a pure metadata stamper: it records the feedback
|
||||
configuration on the method, and the Flow engine collects and routes
|
||||
feedback after the method completes, driven by the flow's definition.
|
||||
"""
|
||||
_validate_human_feedback_options(
|
||||
emit=emit, llm=llm, default_outcome=default_outcome
|
||||
)
|
||||
config = HumanFeedbackConfig(
|
||||
message=message,
|
||||
emit=emit,
|
||||
emit=list(emit) if emit is not None else None,
|
||||
llm=llm,
|
||||
default_outcome=default_outcome,
|
||||
metadata=metadata,
|
||||
@@ -83,8 +53,7 @@ def human_feedback(
|
||||
)
|
||||
|
||||
def decorator(func: F) -> F:
|
||||
wrapper = runtime_decorator(func)
|
||||
_stamp_human_feedback_metadata(wrapper, func, config)
|
||||
return wrapper
|
||||
func.__human_feedback_config__ = config # type: ignore[attr-defined]
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -8,6 +8,7 @@ from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
|
||||
from crewai.flow.dsl._utils import (
|
||||
P,
|
||||
R,
|
||||
_method_action,
|
||||
_set_flow_method_definition,
|
||||
)
|
||||
from crewai.flow.flow_definition import FlowMethodDefinition
|
||||
@@ -45,7 +46,11 @@ def listen(condition: FlowTrigger) -> FlowMethodDecorator:
|
||||
wrapper = ListenMethod(func)
|
||||
|
||||
_set_flow_method_definition(
|
||||
wrapper, FlowMethodDefinition(listen=_to_definition_condition(condition))
|
||||
wrapper,
|
||||
FlowMethodDefinition(
|
||||
do=_method_action(func),
|
||||
listen=_to_definition_condition(condition),
|
||||
),
|
||||
)
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
|
||||
from crewai.flow.dsl._utils import (
|
||||
P,
|
||||
R,
|
||||
_method_action,
|
||||
_set_flow_method_definition,
|
||||
)
|
||||
from crewai.flow.flow_definition import FlowMethodDefinition
|
||||
@@ -148,6 +149,7 @@ def router(
|
||||
_set_flow_method_definition(
|
||||
wrapper,
|
||||
FlowMethodDefinition(
|
||||
do=_method_action(func),
|
||||
listen=_to_definition_condition(condition),
|
||||
router=True,
|
||||
emit=router_events or None,
|
||||
|
||||
@@ -8,6 +8,7 @@ from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
|
||||
from crewai.flow.dsl._utils import (
|
||||
P,
|
||||
R,
|
||||
_method_action,
|
||||
_set_flow_method_definition,
|
||||
)
|
||||
from crewai.flow.flow_definition import FlowMethodDefinition
|
||||
@@ -53,13 +54,17 @@ def start(
|
||||
def decorator(func: Callable[P, R]) -> StartMethod[P, R]:
|
||||
wrapper = StartMethod(func)
|
||||
|
||||
if condition is not None:
|
||||
_set_flow_method_definition(
|
||||
wrapper,
|
||||
FlowMethodDefinition(start=_to_definition_condition(condition)),
|
||||
)
|
||||
else:
|
||||
_set_flow_method_definition(wrapper, FlowMethodDefinition(start=True))
|
||||
_set_flow_method_definition(
|
||||
wrapper,
|
||||
FlowMethodDefinition(
|
||||
do=_method_action(func),
|
||||
start=(
|
||||
_to_definition_condition(condition)
|
||||
if condition is not None
|
||||
else True
|
||||
),
|
||||
),
|
||||
)
|
||||
return wrapper
|
||||
|
||||
return cast(FlowMethodDecorator, decorator)
|
||||
|
||||
@@ -8,6 +8,8 @@ from pydantic import BaseModel
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from crewai.flow.flow_definition import (
|
||||
FlowActionDefinition,
|
||||
FlowCodeActionDefinition,
|
||||
FlowConfigDefinition,
|
||||
FlowConversationalDefinition,
|
||||
FlowConversationalRouterDefinition,
|
||||
@@ -17,6 +19,7 @@ from crewai.flow.flow_definition import (
|
||||
FlowMethodDefinition,
|
||||
FlowPersistenceDefinition,
|
||||
FlowStateDefinition,
|
||||
_object_ref,
|
||||
)
|
||||
from crewai.flow.flow_wrappers import (
|
||||
FlowMethod,
|
||||
@@ -34,15 +37,12 @@ _FLOW_METHOD_METADATA_ATTRS = [
|
||||
"__flow_method_definition__",
|
||||
"__flow_persistence_config__",
|
||||
"__human_feedback_config__",
|
||||
"_human_feedback_llm",
|
||||
]
|
||||
|
||||
|
||||
def is_flow_method(obj: Any) -> TypeIs[FlowMethod[Any, Any]]:
|
||||
"""Check if the object carries Flow method wrapper metadata."""
|
||||
return hasattr(obj, "__is_flow_method__") or hasattr(
|
||||
obj, _FLOW_METHOD_DEFINITION_ATTR
|
||||
)
|
||||
return hasattr(obj, _FLOW_METHOD_DEFINITION_ATTR)
|
||||
|
||||
|
||||
def _should_include_flow_method(flow_class: type, method: Any) -> bool:
|
||||
@@ -80,10 +80,13 @@ def _stamp_inherited_conversational_metadata(
|
||||
for attr in _FLOW_METHOD_METADATA_ATTRS:
|
||||
if hasattr(inherited, attr):
|
||||
setattr(method, attr, getattr(inherited, attr))
|
||||
method.__is_flow_method__ = True
|
||||
return method
|
||||
|
||||
|
||||
def _method_action(method: Any) -> FlowActionDefinition:
|
||||
return FlowCodeActionDefinition(ref=f"{method.__module__}:{method.__qualname__}")
|
||||
|
||||
|
||||
def _set_flow_method_definition(
|
||||
wrapper: FlowMethod[P, R],
|
||||
definition: FlowMethodDefinition,
|
||||
@@ -100,13 +103,6 @@ def _get_flow_method_definition(method: Any) -> FlowMethodDefinition | None:
|
||||
return None
|
||||
|
||||
|
||||
def _object_ref(value: Any) -> str:
|
||||
target = value if isinstance(value, type) else type(value)
|
||||
module = getattr(target, "__module__", "")
|
||||
qualname = getattr(target, "__qualname__", getattr(target, "__name__", ""))
|
||||
return f"{module}:{qualname}" if module and qualname else repr(value)
|
||||
|
||||
|
||||
def _is_json_serializable(value: Any) -> bool:
|
||||
try:
|
||||
json.dumps(value)
|
||||
@@ -214,16 +210,22 @@ def _build_config_definition(
|
||||
) -> FlowConfigDefinition:
|
||||
config_field_names = set(FlowConfigDefinition.model_fields)
|
||||
field_defaults = {
|
||||
name: field.default
|
||||
name: field.get_default(call_default_factory=True)
|
||||
for name, field in getattr(flow_class, "model_fields", {}).items()
|
||||
if name in config_field_names
|
||||
}
|
||||
values: dict[str, Any] = {}
|
||||
for field_name, default in field_defaults.items():
|
||||
value = getattr(flow_class, field_name, default)
|
||||
values[field_name] = _serialize_static_value(
|
||||
value, diagnostics, f"config.{field_name}"
|
||||
)
|
||||
if field_name == "input_provider":
|
||||
# A string value is already a ref; only live objects degrade.
|
||||
values[field_name] = (
|
||||
value if value is None or isinstance(value, str) else _object_ref(value)
|
||||
)
|
||||
else:
|
||||
values[field_name] = _serialize_static_value(
|
||||
value, diagnostics, f"config.{field_name}"
|
||||
)
|
||||
return FlowConfigDefinition(**values)
|
||||
|
||||
|
||||
@@ -239,38 +241,31 @@ def _build_human_feedback_definition(
|
||||
return FlowHumanFeedbackDefinition(
|
||||
message=str(config.message),
|
||||
emit=[str(value) for value in emit] if emit is not None else None,
|
||||
llm=_serialize_static_value(
|
||||
getattr(config, "llm", None), diagnostics, f"{path}.llm"
|
||||
),
|
||||
# llm and provider stay live: the engine consumes them in-process and
|
||||
# the contract degrades them to serializable forms at JSON dump time.
|
||||
llm=getattr(config, "llm", None),
|
||||
default_outcome=getattr(config, "default_outcome", None),
|
||||
metadata=_serialize_static_value(
|
||||
getattr(config, "metadata", None), diagnostics, f"{path}.metadata"
|
||||
),
|
||||
provider=_serialize_static_value(
|
||||
getattr(config, "provider", None), diagnostics, f"{path}.provider"
|
||||
),
|
||||
provider=getattr(config, "provider", None),
|
||||
learn=bool(getattr(config, "learn", False)),
|
||||
learn_source=str(getattr(config, "learn_source", "hitl")),
|
||||
learn_strict=bool(getattr(config, "learn_strict", False)),
|
||||
)
|
||||
|
||||
|
||||
def _build_persistence_definition(
|
||||
value: Any,
|
||||
diagnostics: list[FlowDefinitionDiagnostic],
|
||||
path: str,
|
||||
) -> FlowPersistenceDefinition | None:
|
||||
def _build_persistence_definition(value: Any) -> FlowPersistenceDefinition | None:
|
||||
config = getattr(value, "__flow_persistence_config__", None)
|
||||
if config is None:
|
||||
return None
|
||||
persistence = getattr(config, "persistence", None)
|
||||
verbose = bool(getattr(config, "verbose", False))
|
||||
return FlowPersistenceDefinition(
|
||||
enabled=True,
|
||||
verbose=verbose,
|
||||
persistence=_serialize_static_value(
|
||||
persistence, diagnostics, f"{path}.persistence"
|
||||
),
|
||||
verbose=bool(getattr(config, "verbose", False)),
|
||||
# The backend stays live: the engine persists through the exact
|
||||
# instance the user configured; the contract degrades it to a
|
||||
# serialized config at JSON dump time.
|
||||
persistence=getattr(config, "persistence", None),
|
||||
)
|
||||
|
||||
|
||||
@@ -373,9 +368,11 @@ def _build_method_definition(
|
||||
) -> FlowMethodDefinition:
|
||||
fragment = _get_flow_method_definition(method)
|
||||
if fragment is None:
|
||||
method_definition = FlowMethodDefinition()
|
||||
method_definition = FlowMethodDefinition(do=_method_action(method))
|
||||
else:
|
||||
method_definition = fragment.model_copy(deep=True)
|
||||
method_definition = fragment.model_copy(
|
||||
deep=True, update={"do": _method_action(method)}
|
||||
)
|
||||
|
||||
human_feedback = _build_human_feedback_definition(
|
||||
method, diagnostics, f"{path}.human_feedback"
|
||||
@@ -386,9 +383,7 @@ def _build_method_definition(
|
||||
method_definition.router = True
|
||||
method_definition.emit = None
|
||||
|
||||
method_definition.persist = _build_persistence_definition(
|
||||
method, diagnostics, f"{path}.persist"
|
||||
)
|
||||
method_definition.persist = _build_persistence_definition(method)
|
||||
|
||||
return method_definition
|
||||
|
||||
@@ -472,7 +467,7 @@ def _build_flow_definition_from_class(
|
||||
description=description,
|
||||
state=_build_state_definition(flow_class, diagnostics),
|
||||
config=_build_config_definition(flow_class, diagnostics),
|
||||
persist=_build_persistence_definition(flow_class, diagnostics, "persist"),
|
||||
persist=_build_persistence_definition(flow_class),
|
||||
conversational=_build_conversational_definition(flow_class, diagnostics),
|
||||
methods=methods,
|
||||
diagnostics=diagnostics,
|
||||
|
||||
@@ -11,9 +11,17 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Literal as TypingLiteral
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_serializer,
|
||||
model_validator,
|
||||
)
|
||||
import yaml
|
||||
|
||||
from crewai.flow.conversational_definition import (
|
||||
@@ -25,21 +33,36 @@ from crewai.flow.conversational_definition import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FlowDefinitionCondition = str | dict[str, Any]
|
||||
_STEP_NAME_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
__all__ = [
|
||||
"FlowActionDefinition",
|
||||
"FlowCodeActionDefinition",
|
||||
"FlowConfigDefinition",
|
||||
"FlowConversationalDefinition",
|
||||
"FlowConversationalRouterDefinition",
|
||||
"FlowDefinition",
|
||||
"FlowDefinitionCondition",
|
||||
"FlowDefinitionDiagnostic",
|
||||
"FlowEachActionDefinition",
|
||||
"FlowEachInnerActionDefinition",
|
||||
"FlowExpressionActionDefinition",
|
||||
"FlowHumanFeedbackDefinition",
|
||||
"FlowMethodDefinition",
|
||||
"FlowPersistenceDefinition",
|
||||
"FlowStateDefinition",
|
||||
"FlowToolActionDefinition",
|
||||
]
|
||||
|
||||
|
||||
def _object_ref(value: Any) -> str:
|
||||
"""Format a class or instance as the canonical ``module:qualname`` ref."""
|
||||
target = value if isinstance(value, type) else type(value)
|
||||
module = getattr(target, "__module__", "")
|
||||
qualname = getattr(target, "__qualname__", getattr(target, "__name__", ""))
|
||||
return f"{module}:{qualname}" if module and qualname else repr(value)
|
||||
|
||||
|
||||
class FlowDefinitionDiagnostic(BaseModel):
|
||||
"""A non-fatal Flow Definition build or validation diagnostic."""
|
||||
|
||||
@@ -52,9 +75,10 @@ class FlowDefinitionDiagnostic(BaseModel):
|
||||
class FlowStateDefinition(BaseModel):
|
||||
"""Static description of a Flow state contract."""
|
||||
|
||||
type: TypingLiteral["dict", "pydantic", "unknown"] = "dict"
|
||||
type: TypingLiteral["dict", "pydantic", "json_schema", "unknown"] = "dict"
|
||||
ref: str | None = None
|
||||
default: Any = None
|
||||
json_schema: dict[str, Any] | None = None
|
||||
default: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class FlowConfigDefinition(BaseModel):
|
||||
@@ -62,22 +86,50 @@ class FlowConfigDefinition(BaseModel):
|
||||
|
||||
tracing: bool | None = None
|
||||
stream: bool = False
|
||||
memory: Any = None
|
||||
input_provider: Any = None
|
||||
memory: dict[str, Any] | None = None
|
||||
input_provider: str | None = None
|
||||
suppress_flow_events: bool = False
|
||||
max_method_calls: int = 100
|
||||
defer_trace_finalization: bool = False
|
||||
checkpoint: bool | dict[str, Any] | None = None
|
||||
|
||||
|
||||
class FlowPersistenceDefinition(BaseModel):
|
||||
"""Static persistence configuration."""
|
||||
"""Static persistence configuration.
|
||||
|
||||
``persistence`` may hold a live backend when the definition is built from
|
||||
a decorated class — the engine then persists through the exact instance
|
||||
the user configured; the JSON/YAML projection degrades it to its
|
||||
serialized config.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
verbose: bool = False
|
||||
persistence: Any = None
|
||||
|
||||
@field_serializer("persistence", when_used="json")
|
||||
def _serialize_persistence(self, value: Any) -> Any:
|
||||
if value is None or isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, BaseModel):
|
||||
try:
|
||||
return value.model_dump(mode="json")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Persistence backend %s is not fully serializable; "
|
||||
"preserved import reference only.",
|
||||
_object_ref(value),
|
||||
)
|
||||
return {"ref": _object_ref(value)}
|
||||
|
||||
|
||||
class FlowHumanFeedbackDefinition(BaseModel):
|
||||
"""Static human feedback configuration."""
|
||||
"""Static human feedback configuration.
|
||||
|
||||
``llm`` and ``provider`` may hold live Python objects when the definition
|
||||
is built from a decorated class; the JSON/YAML projection degrades them to
|
||||
a serialized config (``llm``) or a ``module:qualname`` ref (``provider``).
|
||||
"""
|
||||
|
||||
message: str
|
||||
emit: list[str] | None = None
|
||||
@@ -89,10 +141,120 @@ class FlowHumanFeedbackDefinition(BaseModel):
|
||||
learn_source: str = "hitl"
|
||||
learn_strict: bool = False
|
||||
|
||||
@field_serializer("llm", when_used="json")
|
||||
def _serialize_llm(self, value: Any) -> dict[str, Any] | str | None:
|
||||
if value is None or isinstance(value, (str, dict)):
|
||||
return value
|
||||
from crewai.flow.human_feedback import _serialize_llm_for_context
|
||||
|
||||
return _serialize_llm_for_context(value)
|
||||
|
||||
@field_serializer("provider", when_used="json")
|
||||
def _serialize_provider(self, value: Any) -> str | None:
|
||||
if value is None or isinstance(value, str):
|
||||
return value
|
||||
return _object_ref(value)
|
||||
|
||||
|
||||
class FlowCodeActionDefinition(BaseModel):
|
||||
"""A Flow method action that executes importable Python code."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
call: TypingLiteral["code"] = "code"
|
||||
ref: str
|
||||
with_: dict[str, Any] | None = Field(default=None, alias="with")
|
||||
|
||||
|
||||
class FlowToolActionDefinition(BaseModel):
|
||||
"""A Flow method action that invokes a CrewAI tool."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
call: TypingLiteral["tool"]
|
||||
ref: str
|
||||
with_: dict[str, Any] | None = Field(default=None, alias="with")
|
||||
|
||||
|
||||
class FlowExpressionActionDefinition(BaseModel):
|
||||
"""A Flow method action that evaluates a CEL expression."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
call: TypingLiteral["expression"]
|
||||
expr: str
|
||||
|
||||
|
||||
FlowInnerActionDefinition = (
|
||||
FlowCodeActionDefinition | FlowToolActionDefinition | FlowExpressionActionDefinition
|
||||
)
|
||||
|
||||
|
||||
class FlowEachInnerActionDefinition(RootModel[dict[str, FlowInnerActionDefinition]]):
|
||||
"""One named action inside an ``each`` composite action."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return next(iter(self.root))
|
||||
|
||||
@property
|
||||
def action(self) -> FlowInnerActionDefinition:
|
||||
return next(iter(self.root.values()))
|
||||
|
||||
|
||||
class FlowEachActionDefinition(BaseModel):
|
||||
"""A composite action that runs a sequential mini-pipeline for each item."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
call: TypingLiteral["each"]
|
||||
in_: str = Field(alias="in")
|
||||
do: list[FlowEachInnerActionDefinition]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _validate_inner_action_list(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict) or "do" not in data:
|
||||
return data
|
||||
|
||||
inner_actions = data["do"]
|
||||
if not isinstance(inner_actions, list) or not inner_actions:
|
||||
raise ValueError("each.do must contain at least one action")
|
||||
|
||||
seen: set[str] = set()
|
||||
for inner_action in inner_actions:
|
||||
if isinstance(inner_action, FlowEachInnerActionDefinition):
|
||||
action_mapping = inner_action.root
|
||||
elif isinstance(inner_action, dict):
|
||||
action_mapping = inner_action
|
||||
else:
|
||||
raise ValueError("each.do entries must be one-key mappings")
|
||||
|
||||
if len(action_mapping) != 1:
|
||||
raise ValueError("each.do entries must be one-key mappings")
|
||||
|
||||
name = next(iter(action_mapping))
|
||||
_validate_step_name(name, field="each.do action names")
|
||||
if name in seen:
|
||||
raise ValueError(f"each.do action names must be unique: {name!r}")
|
||||
seen.add(name)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
FlowActionDefinition = (
|
||||
FlowCodeActionDefinition
|
||||
| FlowToolActionDefinition
|
||||
| FlowExpressionActionDefinition
|
||||
| FlowEachActionDefinition
|
||||
)
|
||||
|
||||
|
||||
class FlowMethodDefinition(BaseModel):
|
||||
"""Static definition of one Flow method and its execution roles."""
|
||||
|
||||
description: str | None = None
|
||||
do: FlowActionDefinition
|
||||
start: bool | FlowDefinitionCondition | None = None
|
||||
listen: FlowDefinitionCondition | None = None
|
||||
router: bool = False
|
||||
@@ -100,6 +262,16 @@ class FlowMethodDefinition(BaseModel):
|
||||
human_feedback: FlowHumanFeedbackDefinition | None = None
|
||||
persist: FlowPersistenceDefinition | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _canonicalize_human_feedback_routing(self) -> FlowMethodDefinition:
|
||||
# Canonical shape: a method whose human_feedback declares emit
|
||||
# outcomes routes like a router, regardless of how the definition
|
||||
# was authored.
|
||||
if self.human_feedback is not None and self.human_feedback.emit:
|
||||
self.router = True
|
||||
self.emit = None
|
||||
return self
|
||||
|
||||
@property
|
||||
def is_start(self) -> bool:
|
||||
"""Whether this method is a start method.
|
||||
@@ -116,7 +288,9 @@ class FlowDefinition(BaseModel):
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
|
||||
schema_: str = Field(default="crewai.flow/v1", alias="schema")
|
||||
schema_: TypingLiteral["crewai.flow/v1"] = Field(
|
||||
default="crewai.flow/v1", alias="schema"
|
||||
)
|
||||
name: str
|
||||
description: str | None = None
|
||||
state: FlowStateDefinition | None = None
|
||||
@@ -126,6 +300,12 @@ class FlowDefinition(BaseModel):
|
||||
methods: dict[str, FlowMethodDefinition] = Field(default_factory=dict)
|
||||
diagnostics: list[FlowDefinitionDiagnostic] = Field(default_factory=list)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_method_names(self) -> FlowDefinition:
|
||||
for method_name in self.methods:
|
||||
_validate_step_name(method_name, field="Flow method names")
|
||||
return self
|
||||
|
||||
def to_dict(self, *, exclude_none: bool = True) -> dict[str, Any]:
|
||||
"""Serialize the definition to a JSON/YAML-ready dictionary."""
|
||||
return self.model_dump(by_alias=True, exclude_none=exclude_none, mode="json")
|
||||
@@ -268,6 +448,11 @@ def _deserialize_diagnostics(value: Any) -> list[FlowDefinitionDiagnostic]:
|
||||
return [FlowDefinitionDiagnostic.model_validate(item) for item in value or []]
|
||||
|
||||
|
||||
def _validate_step_name(name: str, *, field: str) -> None:
|
||||
if not isinstance(name, str) or not _STEP_NAME_PATTERN.fullmatch(name):
|
||||
raise ValueError(f"{field} must match {_STEP_NAME_PATTERN.pattern}")
|
||||
|
||||
|
||||
def _merge_diagnostics(
|
||||
*diagnostic_groups: list[FlowDefinitionDiagnostic],
|
||||
) -> list[FlowDefinitionDiagnostic]:
|
||||
|
||||
@@ -83,7 +83,6 @@ class FlowMethod(Generic[P, R]):
|
||||
"__conversational_only__", # gates registration on Flow.conversational
|
||||
"__flow_persistence_config__",
|
||||
"__flow_method_definition__",
|
||||
"_human_feedback_llm", # Live LLM object for HITL resume
|
||||
]:
|
||||
if hasattr(meth, attr):
|
||||
setattr(self, attr, getattr(meth, attr))
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
"""Human feedback decorator for Flow methods.
|
||||
"""Human feedback support for Flow methods.
|
||||
|
||||
This module provides the @human_feedback decorator that enables human-in-the-loop
|
||||
workflows within CrewAI Flows. It allows collecting human feedback on method outputs
|
||||
and optionally routing to different listeners based on the feedback.
|
||||
This module backs the @human_feedback decorator that enables human-in-the-loop
|
||||
workflows within CrewAI Flows. The decorator is a pure metadata stamper: it
|
||||
records a :class:`HumanFeedbackConfig` on the method, the Flow definition
|
||||
builder lifts it into ``FlowHumanFeedbackDefinition``, and the Flow engine
|
||||
collects feedback after each decorated method completes, driven by the flow's
|
||||
definition.
|
||||
|
||||
Supports both synchronous (blocking) and asynchronous (non-blocking) feedback
|
||||
collection through the provider parameter.
|
||||
@@ -55,22 +58,18 @@ Example (asynchronous with custom provider):
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.flow.flow_wrappers import FlowMethod
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackProvider
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.flow.runtime import Flow
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
|
||||
@@ -160,8 +159,8 @@ class HumanFeedbackResult:
|
||||
class HumanFeedbackConfig:
|
||||
"""Configuration for the @human_feedback decorator.
|
||||
|
||||
Stores the parameters passed to the decorator for later use during
|
||||
method execution and for introspection by visualization tools.
|
||||
Stores the parameters passed to the decorator for later use by the
|
||||
Flow definition builder and for introspection by visualization tools.
|
||||
|
||||
Attributes:
|
||||
message: The message shown to the human when requesting feedback.
|
||||
@@ -183,19 +182,6 @@ class HumanFeedbackConfig:
|
||||
learn_strict: bool = False
|
||||
|
||||
|
||||
class HumanFeedbackMethod(FlowMethod[Any, Any]):
|
||||
"""Wrapper for methods decorated with @human_feedback.
|
||||
|
||||
This wrapper extends FlowMethod to add human feedback specific attributes
|
||||
used by the FlowDefinition builder and runtime feedback handling.
|
||||
|
||||
Attributes:
|
||||
__human_feedback_config__: The HumanFeedbackConfig for this method.
|
||||
"""
|
||||
|
||||
__human_feedback_config__: HumanFeedbackConfig | None = None
|
||||
|
||||
|
||||
class PreReviewResult(BaseModel):
|
||||
"""Structured output from the HITL pre-review LLM call."""
|
||||
|
||||
@@ -217,17 +203,11 @@ class DistilledLessons(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
def _build_human_feedback_runtime_decorator(
|
||||
message: str,
|
||||
emit: Sequence[str] | None = None,
|
||||
llm: str | BaseLLM | None = "gpt-4o-mini",
|
||||
default_outcome: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
provider: HumanFeedbackProvider | None = None,
|
||||
learn: bool = False,
|
||||
learn_source: str = "hitl",
|
||||
learn_strict: bool = False,
|
||||
) -> Callable[[F], F]:
|
||||
def _validate_human_feedback_options(
|
||||
emit: Sequence[str] | None,
|
||||
llm: Any,
|
||||
default_outcome: str | None,
|
||||
) -> None:
|
||||
if emit is not None:
|
||||
if not llm:
|
||||
raise ValueError(
|
||||
@@ -244,295 +224,139 @@ def _build_human_feedback_runtime_decorator(
|
||||
elif default_outcome is not None:
|
||||
raise ValueError("default_outcome requires emit to be specified.")
|
||||
|
||||
def decorator(func: F) -> F:
|
||||
def _get_hitl_prompt(key: str) -> str:
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
return I18N_DEFAULT.slice(key)
|
||||
def _get_hitl_prompt(key: str) -> str:
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
|
||||
def _resolve_llm_instance() -> Any:
|
||||
if llm is None:
|
||||
from crewai.llm import LLM
|
||||
return I18N_DEFAULT.slice(key)
|
||||
|
||||
return LLM(model="gpt-4o-mini")
|
||||
if isinstance(llm, str):
|
||||
from crewai.llm import LLM
|
||||
|
||||
return LLM(model=llm)
|
||||
return llm # already a BaseLLM instance
|
||||
def _resolve_llm_instance(llm: Any) -> Any:
|
||||
from crewai.llm import LLM
|
||||
|
||||
def _pre_review_with_lessons(
|
||||
flow_instance: Flow[Any], method_output: Any
|
||||
) -> Any:
|
||||
try:
|
||||
mem = flow_instance.memory
|
||||
if mem is None:
|
||||
return method_output
|
||||
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
|
||||
matches = mem.recall(query, source=learn_source)
|
||||
if not matches:
|
||||
return method_output
|
||||
if llm is None:
|
||||
return LLM(model="gpt-4o-mini")
|
||||
if isinstance(llm, str):
|
||||
return LLM(model=llm)
|
||||
if isinstance(llm, dict):
|
||||
deserialized = _deserialize_llm_from_context(llm)
|
||||
return deserialized if deserialized is not None else LLM(model="gpt-4o-mini")
|
||||
return llm # already a BaseLLM instance
|
||||
|
||||
lessons = "\n".join(f"- {m.record.content}" for m in matches)
|
||||
llm_inst = _resolve_llm_instance()
|
||||
prompt = _get_hitl_prompt("hitl_pre_review_user").format(
|
||||
output=str(method_output),
|
||||
lessons=lessons,
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": _get_hitl_prompt("hitl_pre_review_system"),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
if getattr(llm_inst, "supports_function_calling", lambda: False)():
|
||||
response = llm_inst.call(messages, response_model=PreReviewResult)
|
||||
if isinstance(response, PreReviewResult):
|
||||
return response.improved_output
|
||||
return PreReviewResult.model_validate(response).improved_output
|
||||
reviewed = llm_inst.call(messages)
|
||||
return reviewed if isinstance(reviewed, str) else str(reviewed)
|
||||
except Exception:
|
||||
if learn_strict:
|
||||
logger.warning(
|
||||
"HITL pre-review failed for %s; re-raising (learn_strict=True)",
|
||||
func.__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
logger.warning(
|
||||
"HITL pre-review failed for %s; falling back to raw output",
|
||||
func.__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
return method_output
|
||||
|
||||
def _distill_and_store_lessons(
|
||||
flow_instance: Flow[Any], method_output: Any, raw_feedback: str
|
||||
) -> None:
|
||||
try:
|
||||
mem = flow_instance.memory
|
||||
if mem is None:
|
||||
return
|
||||
llm_inst = _resolve_llm_instance()
|
||||
prompt = _get_hitl_prompt("hitl_distill_user").format(
|
||||
method_name=func.__name__,
|
||||
output=str(method_output),
|
||||
feedback=raw_feedback,
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": _get_hitl_prompt("hitl_distill_system"),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
def _pre_review_with_lessons(
|
||||
flow_instance: Flow[Any],
|
||||
method_name: str,
|
||||
method_output: Any,
|
||||
*,
|
||||
llm: Any,
|
||||
learn_source: str,
|
||||
learn_strict: bool,
|
||||
) -> Any:
|
||||
try:
|
||||
mem = flow_instance.memory
|
||||
if mem is None:
|
||||
return method_output
|
||||
query = f"human feedback lessons for {method_name}: {method_output!s}"
|
||||
matches = mem.recall(query, source=learn_source)
|
||||
if not matches:
|
||||
return method_output
|
||||
|
||||
lessons: list[str] = []
|
||||
if getattr(llm_inst, "supports_function_calling", lambda: False)():
|
||||
response = llm_inst.call(messages, response_model=DistilledLessons)
|
||||
if isinstance(response, DistilledLessons):
|
||||
lessons = response.lessons
|
||||
else:
|
||||
lessons = DistilledLessons.model_validate(response).lessons
|
||||
else:
|
||||
response = llm_inst.call(messages)
|
||||
if isinstance(response, str):
|
||||
lessons = [
|
||||
line.strip("- ").strip()
|
||||
for line in response.strip().split("\n")
|
||||
if line.strip() and line.strip() != "NONE"
|
||||
]
|
||||
|
||||
if lessons:
|
||||
mem.remember_many(lessons, source=learn_source) # type: ignore[union-attr]
|
||||
except Exception:
|
||||
if learn_strict:
|
||||
logger.warning(
|
||||
"HITL lesson distillation failed for %s; re-raising (learn_strict=True)",
|
||||
func.__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
logger.warning(
|
||||
"HITL lesson distillation failed for %s; no lessons stored",
|
||||
func.__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _build_feedback_context(
|
||||
flow_instance: Flow[Any], method_output: Any
|
||||
) -> tuple[Any, Any]:
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
context = PendingFeedbackContext(
|
||||
flow_id=flow_instance.flow_id or "unknown",
|
||||
flow_class=f"{flow_instance.__class__.__module__}.{flow_instance.__class__.__name__}",
|
||||
method_name=func.__name__,
|
||||
method_output=method_output,
|
||||
message=message,
|
||||
emit=list(emit) if emit else None,
|
||||
default_outcome=default_outcome,
|
||||
metadata=metadata or {},
|
||||
llm=llm if isinstance(llm, str) else _serialize_llm_for_context(llm),
|
||||
lessons = "\n".join(f"- {m.record.content}" for m in matches)
|
||||
llm_inst = _resolve_llm_instance(llm)
|
||||
prompt = _get_hitl_prompt("hitl_pre_review_user").format(
|
||||
output=str(method_output),
|
||||
lessons=lessons,
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": _get_hitl_prompt("hitl_pre_review_system"),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
if getattr(llm_inst, "supports_function_calling", lambda: False)():
|
||||
response = llm_inst.call(messages, response_model=PreReviewResult)
|
||||
if isinstance(response, PreReviewResult):
|
||||
return response.improved_output
|
||||
return PreReviewResult.model_validate(response).improved_output
|
||||
reviewed = llm_inst.call(messages)
|
||||
return reviewed if isinstance(reviewed, str) else str(reviewed)
|
||||
except Exception:
|
||||
if learn_strict:
|
||||
logger.warning(
|
||||
"HITL pre-review failed for %s; re-raising (learn_strict=True)",
|
||||
method_name,
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
logger.warning(
|
||||
"HITL pre-review failed for %s; falling back to raw output",
|
||||
method_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return method_output
|
||||
|
||||
effective_provider = provider
|
||||
if effective_provider is None:
|
||||
from crewai.flow.flow_config import flow_config
|
||||
|
||||
effective_provider = flow_config.hitl_provider
|
||||
def _distill_and_store_lessons(
|
||||
flow_instance: Flow[Any],
|
||||
method_name: str,
|
||||
method_output: Any,
|
||||
raw_feedback: str,
|
||||
*,
|
||||
llm: Any,
|
||||
learn_source: str,
|
||||
learn_strict: bool,
|
||||
) -> None:
|
||||
try:
|
||||
mem = flow_instance.memory
|
||||
if mem is None:
|
||||
return
|
||||
llm_inst = _resolve_llm_instance(llm)
|
||||
prompt = _get_hitl_prompt("hitl_distill_user").format(
|
||||
method_name=method_name,
|
||||
output=str(method_output),
|
||||
feedback=raw_feedback,
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": _get_hitl_prompt("hitl_distill_system"),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
return context, effective_provider
|
||||
|
||||
def _request_feedback(flow_instance: Flow[Any], method_output: Any) -> str:
|
||||
context, effective_provider = _build_feedback_context(
|
||||
flow_instance, method_output
|
||||
)
|
||||
|
||||
if effective_provider is not None:
|
||||
feedback_result = effective_provider.request_feedback(
|
||||
context, flow_instance
|
||||
)
|
||||
if asyncio.iscoroutine(feedback_result):
|
||||
raise TypeError(
|
||||
f"Provider {type(effective_provider).__name__}.request_feedback() "
|
||||
"returned a coroutine in a sync flow method. Use an async flow "
|
||||
"method or a synchronous provider."
|
||||
)
|
||||
return str(feedback_result)
|
||||
return flow_instance._request_human_feedback(
|
||||
message=message,
|
||||
output=method_output,
|
||||
metadata=metadata,
|
||||
emit=emit,
|
||||
)
|
||||
|
||||
async def _request_feedback_async(
|
||||
flow_instance: Flow[Any], method_output: Any
|
||||
) -> str:
|
||||
context, effective_provider = _build_feedback_context(
|
||||
flow_instance, method_output
|
||||
)
|
||||
|
||||
if effective_provider is not None:
|
||||
feedback_result = effective_provider.request_feedback(
|
||||
context, flow_instance
|
||||
)
|
||||
if asyncio.iscoroutine(feedback_result):
|
||||
return str(await feedback_result)
|
||||
return str(feedback_result)
|
||||
return flow_instance._request_human_feedback(
|
||||
message=message,
|
||||
output=method_output,
|
||||
metadata=metadata,
|
||||
emit=emit,
|
||||
)
|
||||
|
||||
def _process_feedback(
|
||||
flow_instance: Flow[Any],
|
||||
method_output: Any,
|
||||
raw_feedback: str,
|
||||
) -> HumanFeedbackResult | str:
|
||||
collapsed_outcome: str | None = None
|
||||
|
||||
if not raw_feedback.strip():
|
||||
if default_outcome:
|
||||
collapsed_outcome = default_outcome
|
||||
elif emit:
|
||||
collapsed_outcome = emit[0]
|
||||
elif emit:
|
||||
if llm is not None:
|
||||
collapsed_outcome = flow_instance._collapse_to_outcome(
|
||||
feedback=raw_feedback,
|
||||
outcomes=emit,
|
||||
llm=llm,
|
||||
)
|
||||
else:
|
||||
collapsed_outcome = emit[0]
|
||||
|
||||
result = HumanFeedbackResult(
|
||||
output=method_output,
|
||||
feedback=raw_feedback,
|
||||
outcome=collapsed_outcome,
|
||||
timestamp=datetime.now(),
|
||||
method_name=func.__name__,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
flow_instance.human_feedback_history.append(result)
|
||||
flow_instance.last_human_feedback = result
|
||||
|
||||
if emit:
|
||||
if collapsed_outcome is None:
|
||||
collapsed_outcome = default_outcome or emit[0]
|
||||
result.outcome = collapsed_outcome
|
||||
return collapsed_outcome
|
||||
return result
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(self: Flow[Any], *args: Any, **kwargs: Any) -> Any:
|
||||
method_output = await func(self, *args, **kwargs)
|
||||
|
||||
if learn and getattr(self, "memory", None) is not None:
|
||||
method_output = _pre_review_with_lessons(self, method_output)
|
||||
|
||||
raw_feedback = await _request_feedback_async(self, method_output)
|
||||
result = _process_feedback(self, method_output, raw_feedback)
|
||||
|
||||
if (
|
||||
learn
|
||||
and getattr(self, "memory", None) is not None
|
||||
and raw_feedback.strip()
|
||||
):
|
||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||
|
||||
# Stash the real method output for final flow result when emit is set:
|
||||
# result is the collapsed outcome string for routing, but we preserve the
|
||||
# actual method output as the flow's final result. Uses per-method dict for
|
||||
# concurrency safety and to handle None returns.
|
||||
if emit:
|
||||
self._human_feedback_method_outputs[func.__name__] = method_output
|
||||
|
||||
return result
|
||||
|
||||
wrapper: Any = async_wrapper
|
||||
lessons: list[str] = []
|
||||
if getattr(llm_inst, "supports_function_calling", lambda: False)():
|
||||
response = llm_inst.call(messages, response_model=DistilledLessons)
|
||||
if isinstance(response, DistilledLessons):
|
||||
lessons = response.lessons
|
||||
else:
|
||||
lessons = DistilledLessons.model_validate(response).lessons
|
||||
else:
|
||||
response = llm_inst.call(messages)
|
||||
if isinstance(response, str):
|
||||
lessons = [
|
||||
line.strip("- ").strip()
|
||||
for line in response.strip().split("\n")
|
||||
if line.strip() and line.strip() != "NONE"
|
||||
]
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(self: Flow[Any], *args: Any, **kwargs: Any) -> Any:
|
||||
method_output = func(self, *args, **kwargs)
|
||||
|
||||
if learn and getattr(self, "memory", None) is not None:
|
||||
method_output = _pre_review_with_lessons(self, method_output)
|
||||
|
||||
raw_feedback = _request_feedback(self, method_output)
|
||||
result = _process_feedback(self, method_output, raw_feedback)
|
||||
|
||||
if (
|
||||
learn
|
||||
and getattr(self, "memory", None) is not None
|
||||
and raw_feedback.strip()
|
||||
):
|
||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||
|
||||
# Stash the real method output for final flow result when emit is set:
|
||||
# result is the collapsed outcome string for routing, but we preserve the
|
||||
# actual method output as the flow's final result. Uses per-method dict for
|
||||
# concurrency safety and to handle None returns.
|
||||
if emit:
|
||||
self._human_feedback_method_outputs[func.__name__] = method_output
|
||||
|
||||
return result
|
||||
|
||||
wrapper = sync_wrapper
|
||||
|
||||
return wrapper # type: ignore[no-any-return]
|
||||
|
||||
return decorator
|
||||
if lessons:
|
||||
mem.remember_many(lessons, source=learn_source) # type: ignore[union-attr]
|
||||
except Exception:
|
||||
if learn_strict:
|
||||
logger.warning(
|
||||
"HITL lesson distillation failed for %s; re-raising (learn_strict=True)",
|
||||
method_name,
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
logger.warning(
|
||||
"HITL lesson distillation failed for %s; no lessons stored",
|
||||
method_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def human_feedback(
|
||||
|
||||
@@ -24,12 +24,10 @@ Example:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
from typing import TYPE_CHECKING, Any, Final, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Final, TypeVar
|
||||
|
||||
from crewai_core.printer import PRINTER
|
||||
from pydantic import BaseModel
|
||||
@@ -39,7 +37,7 @@ from crewai.flow.persistence.factory import default_flow_persistence
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -66,14 +64,6 @@ def _stamp_persistence_metadata(
|
||||
)
|
||||
|
||||
|
||||
_PRESERVED_FLOW_ATTRS: Final[tuple[str, ...]] = (
|
||||
"__human_feedback_config__",
|
||||
"__flow_persistence_config__",
|
||||
"__flow_method_definition__",
|
||||
"_human_feedback_llm",
|
||||
)
|
||||
|
||||
|
||||
class PersistenceDecorator:
|
||||
"""Class to handle flow state persistence with consistent logging."""
|
||||
|
||||
@@ -164,6 +154,10 @@ def persist(
|
||||
states. When applied at the method level, it persists only that method's
|
||||
state.
|
||||
|
||||
The decorator is a pure metadata stamper: it records the persistence
|
||||
configuration on the class or method, and the Flow engine saves state
|
||||
after each persisted method completes, driven by the flow's definition.
|
||||
|
||||
Args:
|
||||
persistence: Optional FlowPersistence implementation to use.
|
||||
If not provided, uses ``default_flow_persistence()`` (the
|
||||
@@ -191,122 +185,7 @@ def persist(
|
||||
persistence if persistence is not None else default_flow_persistence()
|
||||
)
|
||||
|
||||
if isinstance(target, type):
|
||||
_stamp_persistence_metadata(target, actual_persistence, verbose)
|
||||
original_init = target.__init__ # type: ignore[misc]
|
||||
|
||||
@functools.wraps(original_init)
|
||||
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
|
||||
if "persistence" not in kwargs:
|
||||
kwargs["persistence"] = actual_persistence
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
target.__init__ = new_init # type: ignore[misc]
|
||||
|
||||
# Preserve original methods' decorators
|
||||
original_methods = {
|
||||
name: method
|
||||
for name, method in target.__dict__.items()
|
||||
if callable(method)
|
||||
and (
|
||||
hasattr(method, "__is_flow_method__")
|
||||
or hasattr(method, "__flow_method_definition__")
|
||||
)
|
||||
}
|
||||
|
||||
for name, method in original_methods.items():
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
# Closure captures the current name and method
|
||||
def create_async_wrapper(
|
||||
method_name: str, original_method: Callable[..., Any]
|
||||
) -> Callable[..., Any]:
|
||||
@functools.wraps(original_method)
|
||||
async def method_wrapper(
|
||||
self: Any, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
result = await original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
self, method_name, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
return method_wrapper
|
||||
|
||||
wrapped = create_async_wrapper(name, method)
|
||||
|
||||
for attr in _PRESERVED_FLOW_ATTRS:
|
||||
if hasattr(method, attr):
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
setattr(target, name, wrapped)
|
||||
else:
|
||||
|
||||
def create_sync_wrapper(
|
||||
method_name: str, original_method: Callable[..., Any]
|
||||
) -> Callable[..., Any]:
|
||||
@functools.wraps(original_method)
|
||||
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
result = original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
self, method_name, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
return method_wrapper
|
||||
|
||||
wrapped = create_sync_wrapper(name, method)
|
||||
|
||||
for attr in _PRESERVED_FLOW_ATTRS:
|
||||
if hasattr(method, attr):
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
setattr(target, name, wrapped)
|
||||
|
||||
return target
|
||||
method = target
|
||||
method.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
_stamp_persistence_metadata(method, actual_persistence, verbose)
|
||||
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
|
||||
@functools.wraps(method)
|
||||
async def method_async_wrapper(
|
||||
flow_instance: Any, *args: Any, **kwargs: Any
|
||||
) -> T:
|
||||
method_coro = method(flow_instance, *args, **kwargs)
|
||||
if asyncio.iscoroutine(method_coro):
|
||||
result = await method_coro
|
||||
else:
|
||||
result = method_coro
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance, method.__name__, actual_persistence, verbose
|
||||
)
|
||||
return cast(T, result)
|
||||
|
||||
for attr in _PRESERVED_FLOW_ATTRS:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_async_wrapper, attr, getattr(method, attr))
|
||||
method_async_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
_stamp_persistence_metadata(
|
||||
method_async_wrapper, actual_persistence, verbose
|
||||
)
|
||||
return cast(Callable[..., T], method_async_wrapper)
|
||||
|
||||
@functools.wraps(method)
|
||||
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||
result = method(flow_instance, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance, method.__name__, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
for attr in _PRESERVED_FLOW_ATTRS:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_sync_wrapper, attr, getattr(method, attr))
|
||||
method_sync_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
_stamp_persistence_metadata(method_sync_wrapper, actual_persistence, verbose)
|
||||
return cast(Callable[..., T], method_sync_wrapper)
|
||||
_stamp_persistence_metadata(target, actual_persistence, verbose)
|
||||
return target
|
||||
|
||||
return decorator
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
48
lib/crewai/src/crewai/flow/runtime/_actions/__init__.py
Normal file
48
lib/crewai/src/crewai/flow/runtime/_actions/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Build FlowDefinition actions into live runtime callables."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.flow_definition import (
|
||||
FlowActionDefinition,
|
||||
FlowInnerActionDefinition,
|
||||
)
|
||||
from crewai.flow.runtime._actions._base import ActionHandlerRegistry
|
||||
from crewai.flow.runtime._actions._code import CodeActionHandler
|
||||
from crewai.flow.runtime._actions._each import EachActionHandler
|
||||
from crewai.flow.runtime._actions._expression import ExpressionActionHandler
|
||||
from crewai.flow.runtime._actions._tool import ToolActionHandler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
__all__ = [
|
||||
"build_action",
|
||||
]
|
||||
|
||||
|
||||
_SIMPLE_ACTION_HANDLERS = (
|
||||
CodeActionHandler(),
|
||||
ToolActionHandler(),
|
||||
ExpressionActionHandler(),
|
||||
)
|
||||
|
||||
_SIMPLE_ACTION_REGISTRY = ActionHandlerRegistry[FlowInnerActionDefinition](
|
||||
_SIMPLE_ACTION_HANDLERS
|
||||
)
|
||||
|
||||
_ACTION_REGISTRY = ActionHandlerRegistry[FlowActionDefinition](
|
||||
(
|
||||
*_SIMPLE_ACTION_HANDLERS,
|
||||
EachActionHandler(_SIMPLE_ACTION_REGISTRY),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_action(flow: Flow[Any], action: FlowActionDefinition) -> Callable[..., Any]:
|
||||
"""Turn one `do:` action into the callable the flow runs for that node."""
|
||||
return _ACTION_REGISTRY.build(flow, action)
|
||||
39
lib/crewai/src/crewai/flow/runtime/_actions/_base.py
Normal file
39
lib/crewai/src/crewai/flow/runtime/_actions/_base.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Shared action handler contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
ActionT = TypeVar("ActionT", bound=BaseModel)
|
||||
ResolvedAction = Callable[..., Any]
|
||||
|
||||
|
||||
class ActionHandler(Protocol[ActionT]):
|
||||
"""Handler for one concrete FlowDefinition action type."""
|
||||
|
||||
action_type: type[ActionT]
|
||||
|
||||
def build(self, flow: Flow[Any], action: ActionT) -> ResolvedAction:
|
||||
"""Build the callable executed by the flow."""
|
||||
|
||||
|
||||
class ActionHandlerRegistry(Generic[ActionT]):
|
||||
"""Build action callables with an ordered set of typed handlers."""
|
||||
|
||||
def __init__(self, handlers: Iterable[ActionHandler[Any]]) -> None:
|
||||
self._handlers = tuple(handlers)
|
||||
|
||||
def build(self, flow: Flow[Any], action: ActionT) -> ResolvedAction:
|
||||
for handler in self._handlers:
|
||||
if isinstance(action, handler.action_type):
|
||||
return handler.build(flow, action)
|
||||
call = getattr(action, "call", None)
|
||||
raise ValueError(f"unknown call type {call!r}")
|
||||
51
lib/crewai/src/crewai/flow/runtime/_actions/_code.py
Normal file
51
lib/crewai/src/crewai/flow/runtime/_actions/_code.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Handler for ``call: code`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.flow.flow_definition import FlowCodeActionDefinition
|
||||
from crewai.flow.runtime._actions._base import ResolvedAction
|
||||
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
|
||||
from crewai.flow.runtime._expressions import render_with_block
|
||||
from crewai.flow.runtime._refs import InvalidRefError, resolve_ref
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class CodeActionHandler:
|
||||
"""Build importable Python callables and bind them to the running flow."""
|
||||
|
||||
action_type = FlowCodeActionDefinition
|
||||
|
||||
def build(
|
||||
self, flow: Flow[Any], action: FlowCodeActionDefinition
|
||||
) -> ResolvedAction:
|
||||
handler = _resolve_code_handler(flow, action)
|
||||
|
||||
def run_code(*args: Any, **kwargs: Any) -> Any:
|
||||
local_context = kwargs.pop(LOCAL_CONTEXT_KWARG, None)
|
||||
if action.with_ is None:
|
||||
return handler(*args, **kwargs)
|
||||
return handler(
|
||||
**render_with_block(flow, action.with_, local_context=local_context)
|
||||
)
|
||||
|
||||
return functools.update_wrapper(run_code, handler)
|
||||
|
||||
|
||||
def _resolve_code_handler(
|
||||
flow: Flow[Any], action: FlowCodeActionDefinition
|
||||
) -> Callable[..., Any]:
|
||||
ref = action.ref
|
||||
target = resolve_ref(ref, field="do")
|
||||
if not callable(target):
|
||||
raise InvalidRefError(f"invalid do ref {ref!r}; object is not callable")
|
||||
handler = cast(Callable[..., Any], target)
|
||||
if getattr(handler, "__self__", None) is None:
|
||||
handler = handler.__get__(flow, type(flow))
|
||||
return handler
|
||||
73
lib/crewai/src/crewai/flow/runtime/_actions/_each.py
Normal file
73
lib/crewai/src/crewai/flow/runtime/_actions/_each.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Handler for ``call: each`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.flow_definition import (
|
||||
FlowEachActionDefinition,
|
||||
FlowEachInnerActionDefinition,
|
||||
FlowInnerActionDefinition,
|
||||
)
|
||||
from crewai.flow.runtime._actions._base import (
|
||||
ActionHandlerRegistry,
|
||||
ResolvedAction,
|
||||
)
|
||||
from crewai.flow.runtime._actions._runtime import (
|
||||
LOCAL_CONTEXT_KWARG,
|
||||
ensure_array,
|
||||
invoke_callable,
|
||||
)
|
||||
from crewai.flow.runtime._expressions import evaluate_expression
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class EachActionHandler:
|
||||
"""Build a sequential mini-pipeline for every item in an array."""
|
||||
|
||||
action_type = FlowEachActionDefinition
|
||||
|
||||
def __init__(
|
||||
self, inner_registry: ActionHandlerRegistry[FlowInnerActionDefinition]
|
||||
) -> None:
|
||||
self._inner_registry = inner_registry
|
||||
|
||||
def build(
|
||||
self, flow: Flow[Any], action: FlowEachActionDefinition
|
||||
) -> ResolvedAction:
|
||||
inner_actions = [
|
||||
(inner_action.name, self._resolve_inner_action(flow, inner_action))
|
||||
for inner_action in action.do
|
||||
]
|
||||
|
||||
async def run_each(*_args: Any, **_kwargs: Any) -> list[Any]:
|
||||
items = ensure_array(evaluate_expression(flow, action.in_))
|
||||
results: list[Any] = []
|
||||
for item in items:
|
||||
local_outputs: dict[str, Any] = {}
|
||||
last_output: Any = None
|
||||
for name, run_inner_action in inner_actions:
|
||||
last_output = await run_inner_action(
|
||||
{"item": item, "outputs": local_outputs}
|
||||
)
|
||||
local_outputs[name] = last_output
|
||||
results.append(last_output)
|
||||
return results
|
||||
|
||||
return run_each
|
||||
|
||||
def _resolve_inner_action(
|
||||
self, flow: Flow[Any], inner_action: FlowEachInnerActionDefinition
|
||||
) -> Callable[[dict[str, Any]], Any]:
|
||||
run_action = self._inner_registry.build(flow, inner_action.action)
|
||||
|
||||
async def run_inner_action(local_context: dict[str, Any]) -> Any:
|
||||
return await invoke_callable(
|
||||
run_action, **{LOCAL_CONTEXT_KWARG: local_context}
|
||||
)
|
||||
|
||||
return run_inner_action
|
||||
29
lib/crewai/src/crewai/flow/runtime/_actions/_expression.py
Normal file
29
lib/crewai/src/crewai/flow/runtime/_actions/_expression.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Handler for ``call: expression`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.flow_definition import FlowExpressionActionDefinition
|
||||
from crewai.flow.runtime._actions._base import ResolvedAction
|
||||
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
|
||||
from crewai.flow.runtime._expressions import evaluate_expression
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class ExpressionActionHandler:
|
||||
"""Build CEL expression actions."""
|
||||
|
||||
action_type = FlowExpressionActionDefinition
|
||||
|
||||
def build(
|
||||
self, flow: Flow[Any], action: FlowExpressionActionDefinition
|
||||
) -> ResolvedAction:
|
||||
def run_expression(*_args: Any, **kwargs: Any) -> Any:
|
||||
local_context = kwargs.pop(LOCAL_CONTEXT_KWARG, None)
|
||||
return evaluate_expression(flow, action.expr, local_context=local_context)
|
||||
|
||||
return run_expression
|
||||
28
lib/crewai/src/crewai/flow/runtime/_actions/_runtime.py
Normal file
28
lib/crewai/src/crewai/flow/runtime/_actions/_runtime.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Runtime helpers shared by action resolvers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
|
||||
LOCAL_CONTEXT_KWARG = "__flow_definition_local_context"
|
||||
|
||||
|
||||
async def invoke_callable(
|
||||
handler: Callable[..., Any], *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
result = await handler(*args, **kwargs)
|
||||
else:
|
||||
result = handler(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
|
||||
def ensure_array(value: Any) -> list[Any]:
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
raise ValueError("each.in must evaluate to an array")
|
||||
52
lib/crewai/src/crewai/flow/runtime/_actions/_tool.py
Normal file
52
lib/crewai/src/crewai/flow/runtime/_actions/_tool.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Handler for ``call: tool`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.flow.flow_definition import FlowToolActionDefinition
|
||||
from crewai.flow.runtime._actions._base import ResolvedAction
|
||||
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
|
||||
from crewai.flow.runtime._expressions import render_with_block
|
||||
from crewai.flow.runtime._refs import InvalidRefError, resolve_ref
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class ToolActionHandler:
|
||||
"""Build and instantiate CrewAI tool actions."""
|
||||
|
||||
action_type = FlowToolActionDefinition
|
||||
|
||||
def build(
|
||||
self, flow: Flow[Any], action: FlowToolActionDefinition
|
||||
) -> ResolvedAction:
|
||||
target = resolve_ref(action.ref, field="do")
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
if not (inspect.isclass(target) and issubclass(target, BaseTool)):
|
||||
raise InvalidRefError(
|
||||
f"invalid tool ref {action.ref!r}; expected a BaseTool class"
|
||||
)
|
||||
|
||||
try:
|
||||
tool_cls = cast(Callable[[], BaseTool], target)
|
||||
tool = tool_cls()
|
||||
except Exception as e:
|
||||
raise InvalidRefError(
|
||||
f"cannot instantiate tool ref {action.ref!r} without arguments: {e}"
|
||||
) from e
|
||||
|
||||
tool_kwargs = action.with_ or {}
|
||||
|
||||
def run_tool(*_args: Any, **kwargs: Any) -> Any:
|
||||
local_context = kwargs.pop(LOCAL_CONTEXT_KWARG, None)
|
||||
return tool.run(
|
||||
**render_with_block(flow, tool_kwargs, local_context=local_context)
|
||||
)
|
||||
|
||||
return run_tool
|
||||
163
lib/crewai/src/crewai/flow/runtime/_expressions.py
Normal file
163
lib/crewai/src/crewai/flow/runtime/_expressions.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Runtime expression support for FlowDefinition CEL expressions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from itertools import pairwise
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
_EXPRESSION_PATTERN = re.compile(r"\$\{([^{}]*)\}")
|
||||
|
||||
__all__ = ["FlowExpressionError", "evaluate_expression", "render_with_block"]
|
||||
|
||||
|
||||
class FlowExpressionError(ValueError):
|
||||
"""A FlowDefinition expression failed to parse or evaluate."""
|
||||
|
||||
|
||||
def render_with_block(
|
||||
flow: Flow[Any], value: Any, local_context: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Render CEL expressions inside a FlowDefinition ``with:`` payload."""
|
||||
context = _expression_context(flow, local_context=local_context)
|
||||
return _render_value(value, context)
|
||||
|
||||
|
||||
def evaluate_expression(
|
||||
flow: Flow[Any], expression: str, local_context: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Evaluate a FlowDefinition CEL expression against runtime context."""
|
||||
expression = expression.strip()
|
||||
if not expression:
|
||||
raise FlowExpressionError("empty CEL expression")
|
||||
return _eval_cel(expression, _expression_context(flow, local_context=local_context))
|
||||
|
||||
|
||||
def _expression_context(
|
||||
flow: Flow[Any], local_context: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
context = {
|
||||
"state": flow._copy_and_serialize_state(),
|
||||
"outputs": _outputs_by_name(flow._method_outputs),
|
||||
}
|
||||
if local_context:
|
||||
context.update(
|
||||
{key: _to_json_safe(value) for key, value in local_context.items()}
|
||||
)
|
||||
return context
|
||||
|
||||
|
||||
def _outputs_by_name(method_outputs: list[Any]) -> dict[str, Any]:
|
||||
outputs: dict[str, Any] = {}
|
||||
for entry in method_outputs:
|
||||
method = ""
|
||||
output = entry
|
||||
if isinstance(entry, dict) and "output" in entry:
|
||||
method = str(entry.get("method", ""))
|
||||
output = entry["output"]
|
||||
outputs[method] = _to_json_safe(output)
|
||||
return outputs
|
||||
|
||||
|
||||
def _to_json_safe(value: Any) -> Any:
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(mode="json")
|
||||
if dataclasses.is_dataclass(value) and not isinstance(value, type):
|
||||
return dataclasses.asdict(value)
|
||||
if isinstance(value, dict):
|
||||
return {key: _to_json_safe(item) for key, item in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_to_json_safe(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [_to_json_safe(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _render_value(value: Any, context: dict[str, Any]) -> Any:
|
||||
if isinstance(value, str):
|
||||
return _render_string(value, context)
|
||||
if isinstance(value, dict):
|
||||
return {key: _render_value(item, context) for key, item in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_render_value(item, context) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _render_string(value: str, context: dict[str, Any]) -> Any:
|
||||
matches = list(_EXPRESSION_PATTERN.finditer(value))
|
||||
if not matches:
|
||||
_raise_for_invalid_interpolation(value)
|
||||
return value
|
||||
|
||||
_raise_for_literal_braces(value[: matches[0].start()])
|
||||
for previous, current in pairwise(matches):
|
||||
_raise_for_literal_braces(value[previous.end() : current.start()])
|
||||
_raise_for_literal_braces(value[matches[-1].end() :])
|
||||
|
||||
if len(matches) == 1 and matches[0].span() == (0, len(value)):
|
||||
expression = matches[0].group(1).strip()
|
||||
if not expression:
|
||||
raise FlowExpressionError("empty CEL expression in with block")
|
||||
return _eval_cel(expression, context)
|
||||
|
||||
rendered: list[str] = []
|
||||
position = 0
|
||||
for match in matches:
|
||||
start, end = match.span()
|
||||
literal = value[position:start]
|
||||
rendered.append(literal)
|
||||
|
||||
expression = match.group(1).strip()
|
||||
if not expression:
|
||||
raise FlowExpressionError("empty CEL expression in with block")
|
||||
result = _eval_cel(expression, context)
|
||||
rendered.append(result if isinstance(result, str) else json.dumps(result))
|
||||
position = end
|
||||
|
||||
literal = value[position:]
|
||||
rendered.append(literal)
|
||||
|
||||
return "".join(rendered)
|
||||
|
||||
|
||||
def _raise_for_invalid_interpolation(value: str) -> None:
|
||||
if "${" not in value:
|
||||
return
|
||||
raise FlowExpressionError(
|
||||
"invalid CEL interpolation in with block: expressions must be enclosed "
|
||||
"as ${...} and cannot contain braces"
|
||||
)
|
||||
|
||||
|
||||
def _raise_for_literal_braces(value: str) -> None:
|
||||
if "{" not in value and "}" not in value:
|
||||
return
|
||||
raise FlowExpressionError(
|
||||
"invalid CEL interpolation in with block: expressions must be enclosed "
|
||||
"as ${...} and cannot contain braces"
|
||||
)
|
||||
|
||||
|
||||
def _eval_cel(expression: str, context: dict[str, Any]) -> Any:
|
||||
try:
|
||||
from celpy import Environment
|
||||
from celpy.adapter import CELJSONEncoder, json_to_cel
|
||||
from celpy.evaluation import Context
|
||||
|
||||
environment = Environment()
|
||||
program = environment.program(environment.compile(expression))
|
||||
result = program.evaluate(cast(Context, json_to_cel(context)))
|
||||
return json.loads(json.dumps(result, cls=CELJSONEncoder))
|
||||
except Exception as e:
|
||||
raise FlowExpressionError(
|
||||
f"failed to evaluate CEL expression {expression!r}: {e}"
|
||||
) from e
|
||||
38
lib/crewai/src/crewai/flow/runtime/_refs.py
Normal file
38
lib/crewai/src/crewai/flow/runtime/_refs.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Resolution of ``module:qualname`` refs into live Python objects."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from operator import attrgetter
|
||||
from typing import Any
|
||||
|
||||
|
||||
class InvalidRefError(ValueError):
|
||||
"""A definition ref that cannot be resolved to a live object."""
|
||||
|
||||
|
||||
def resolve_ref(ref: str, *, field: str) -> Any:
|
||||
"""Import the object a definition's `module:qualname` ref points to."""
|
||||
module_name, _, qualname = ref.partition(":")
|
||||
if "<" in ref or not module_name or not qualname:
|
||||
raise InvalidRefError(
|
||||
f"invalid {field} ref {ref!r}; expected 'module:qualname'"
|
||||
)
|
||||
try:
|
||||
return attrgetter(qualname)(importlib.import_module(module_name))
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise InvalidRefError(f"unresolvable {field} ref {ref!r}") from e
|
||||
|
||||
|
||||
def resolve_instance_ref(ref: str, *, field: str) -> Any:
|
||||
"""Resolve a ref, auto-instantiating a no-arg class into an instance."""
|
||||
target = resolve_ref(ref, field=field)
|
||||
if not inspect.isclass(target):
|
||||
return target
|
||||
try:
|
||||
return target()
|
||||
except Exception as e:
|
||||
raise InvalidRefError(
|
||||
f"cannot instantiate {field} ref {ref!r} without arguments: {e}"
|
||||
) from e
|
||||
@@ -890,41 +890,17 @@ class BaseLLM(BaseModel, ABC):
|
||||
Args:
|
||||
usage_data: Token usage data from the API response
|
||||
"""
|
||||
prompt_tokens = (
|
||||
usage_data.get("prompt_tokens")
|
||||
or usage_data.get("prompt_token_count")
|
||||
or usage_data.get("input_tokens")
|
||||
or 0
|
||||
)
|
||||
metrics = UsageMetrics.from_provider_dict(usage_data)
|
||||
if metrics is None:
|
||||
return
|
||||
|
||||
completion_tokens = (
|
||||
usage_data.get("completion_tokens")
|
||||
or usage_data.get("candidates_token_count")
|
||||
or usage_data.get("output_tokens")
|
||||
or 0
|
||||
)
|
||||
|
||||
cached_tokens = (
|
||||
usage_data.get("cached_tokens")
|
||||
or usage_data.get("cached_prompt_tokens")
|
||||
or usage_data.get("cache_read_input_tokens")
|
||||
or 0
|
||||
)
|
||||
if not cached_tokens:
|
||||
prompt_details = usage_data.get("prompt_tokens_details")
|
||||
if isinstance(prompt_details, dict):
|
||||
cached_tokens = prompt_details.get("cached_tokens", 0) or 0
|
||||
|
||||
reasoning_tokens = usage_data.get("reasoning_tokens", 0) or 0
|
||||
cache_creation_tokens = usage_data.get("cache_creation_tokens", 0) or 0
|
||||
|
||||
self._token_usage["prompt_tokens"] += prompt_tokens
|
||||
self._token_usage["completion_tokens"] += completion_tokens
|
||||
self._token_usage["total_tokens"] += prompt_tokens + completion_tokens
|
||||
self._token_usage["successful_requests"] += 1
|
||||
self._token_usage["cached_prompt_tokens"] += cached_tokens
|
||||
self._token_usage["reasoning_tokens"] += reasoning_tokens
|
||||
self._token_usage["cache_creation_tokens"] += cache_creation_tokens
|
||||
self._token_usage["prompt_tokens"] += metrics.prompt_tokens
|
||||
self._token_usage["completion_tokens"] += metrics.completion_tokens
|
||||
self._token_usage["total_tokens"] += metrics.total_tokens
|
||||
self._token_usage["successful_requests"] += metrics.successful_requests
|
||||
self._token_usage["cached_prompt_tokens"] += metrics.cached_prompt_tokens
|
||||
self._token_usage["reasoning_tokens"] += metrics.reasoning_tokens
|
||||
self._token_usage["cache_creation_tokens"] += metrics.cache_creation_tokens
|
||||
|
||||
def get_token_usage_summary(self) -> UsageMetrics:
|
||||
"""Get summary of token usage for this LLM instance.
|
||||
|
||||
@@ -4,10 +4,31 @@ This module provides models for tracking token usage and request metrics
|
||||
during crew and agent execution.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
def _coerce_int(value: Any) -> int:
|
||||
if value is None:
|
||||
return 0
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
|
||||
def _first_int(usage_data: dict[str, Any], *keys: str) -> int:
|
||||
"""Return the first integer-coercible value from ``usage_data`` under any
|
||||
of ``keys``. Falls back to ``0`` when nothing matches."""
|
||||
for key in keys:
|
||||
coerced = _coerce_int(usage_data.get(key))
|
||||
if coerced:
|
||||
return coerced
|
||||
return 0
|
||||
|
||||
|
||||
class UsageMetrics(BaseModel):
|
||||
"""Track usage metrics for crew execution.
|
||||
|
||||
@@ -54,3 +75,50 @@ class UsageMetrics(BaseModel):
|
||||
self.reasoning_tokens += usage_metrics.reasoning_tokens
|
||||
self.cache_creation_tokens += usage_metrics.cache_creation_tokens
|
||||
self.successful_requests += usage_metrics.successful_requests
|
||||
|
||||
@classmethod
|
||||
def from_provider_dict(cls, usage_data: dict[str, Any] | None) -> Self | None:
|
||||
"""Normalize a provider's raw usage dict into a ``UsageMetrics``.
|
||||
|
||||
Accepts the full set of key aliases CrewAI providers emit:
|
||||
``prompt_tokens`` / ``prompt_token_count`` (Gemini) / ``input_tokens``
|
||||
(Anthropic), and the equivalent completion / cached-prompt aliases.
|
||||
Mirrors ``BaseLLM._track_token_usage_internal`` so per-LLM totals,
|
||||
flow-level aggregation, and OTel spans agree on every provider.
|
||||
|
||||
Returns ``None`` for missing/empty input so callers can decide
|
||||
whether to skip the event entirely or treat it as a zero-token
|
||||
successful request.
|
||||
"""
|
||||
if not usage_data:
|
||||
return None
|
||||
|
||||
prompt_tokens = _first_int(
|
||||
usage_data, "prompt_tokens", "prompt_token_count", "input_tokens"
|
||||
)
|
||||
completion_tokens = _first_int(
|
||||
usage_data,
|
||||
"completion_tokens",
|
||||
"candidates_token_count",
|
||||
"output_tokens",
|
||||
)
|
||||
cached_prompt_tokens = _first_int(
|
||||
usage_data,
|
||||
"cached_tokens",
|
||||
"cached_prompt_tokens",
|
||||
"cache_read_input_tokens",
|
||||
)
|
||||
if not cached_prompt_tokens:
|
||||
details = usage_data.get("prompt_tokens_details")
|
||||
if isinstance(details, dict):
|
||||
cached_prompt_tokens = _coerce_int(details.get("cached_tokens"))
|
||||
|
||||
return cls(
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cached_prompt_tokens=cached_prompt_tokens,
|
||||
reasoning_tokens=_coerce_int(usage_data.get("reasoning_tokens")),
|
||||
cache_creation_tokens=_coerce_int(usage_data.get("cache_creation_tokens")),
|
||||
successful_requests=1,
|
||||
)
|
||||
|
||||
@@ -999,7 +999,11 @@ def _json_schema_to_pydantic_field(
|
||||
if examples:
|
||||
schema_extra["examples"] = examples
|
||||
|
||||
default = ... if is_required else None
|
||||
default = (
|
||||
json_schema["default"]
|
||||
if "default" in json_schema
|
||||
else (... if is_required else None)
|
||||
)
|
||||
|
||||
if isinstance(type_, type) and issubclass(type_, (int, float)):
|
||||
if "minimum" in json_schema:
|
||||
|
||||
@@ -21,7 +21,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow import Flow, start, listen, human_feedback
|
||||
from crewai.flow import Flow, HumanFeedbackResult, start, listen, human_feedback
|
||||
from crewai.flow.async_feedback import (
|
||||
ConsoleProvider,
|
||||
HumanFeedbackPending,
|
||||
@@ -615,6 +615,45 @@ class TestFlowResumeWithFeedback:
|
||||
|
||||
assert persistence.load_pending_feedback("resume-test-123") is None
|
||||
|
||||
@patch("crewai.flow.runtime.crewai_event_bus.emit")
|
||||
def test_terminal_resume_without_emit_returns_feedback_result(
|
||||
self, mock_emit: MagicMock
|
||||
) -> None:
|
||||
"""Terminal resumed non-emit methods return the full feedback result."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class TestFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(message="Review this:", metadata={"stage": "draft"})
|
||||
def generate(self):
|
||||
return {"content": "generated content"}
|
||||
|
||||
context = PendingFeedbackContext(
|
||||
flow_id="terminal-non-emit-test-123",
|
||||
flow_class="test.TestFlow",
|
||||
method_name="generate",
|
||||
method_output={"content": "generated content"},
|
||||
message="Review this:",
|
||||
metadata={"stage": "draft"},
|
||||
)
|
||||
persistence.save_pending_feedback(
|
||||
flow_uuid="terminal-non-emit-test-123",
|
||||
context=context,
|
||||
state_data={"id": "terminal-non-emit-test-123"},
|
||||
)
|
||||
|
||||
flow = TestFlow.from_pending("terminal-non-emit-test-123", persistence)
|
||||
result = flow.resume("looks good!")
|
||||
|
||||
assert isinstance(result, HumanFeedbackResult)
|
||||
assert result.output == {"content": "generated content"}
|
||||
assert result.feedback == "looks good!"
|
||||
assert result.outcome is None
|
||||
assert result.metadata == {"stage": "draft"}
|
||||
assert flow.method_outputs == [result]
|
||||
|
||||
@patch("crewai.flow.runtime.crewai_event_bus.emit")
|
||||
def test_resume_routing(self, mock_emit: MagicMock) -> None:
|
||||
"""Test resume with routing."""
|
||||
@@ -667,6 +706,93 @@ class TestFlowResumeWithFeedback:
|
||||
assert flow.last_human_feedback.outcome == "approved"
|
||||
assert flow.result_path == "approved"
|
||||
|
||||
@patch("crewai.flow.runtime.crewai_event_bus.emit")
|
||||
def test_terminal_resume_with_emit_returns_method_output(
|
||||
self, mock_emit: MagicMock
|
||||
) -> None:
|
||||
"""Terminal resumed emit methods return the original method output."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
method_output = {"content": "original content", "status": "ready"}
|
||||
|
||||
class TestFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Approve?",
|
||||
emit=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
def review(self):
|
||||
return method_output
|
||||
|
||||
context = PendingFeedbackContext(
|
||||
flow_id="terminal-route-test-123",
|
||||
flow_class="test.TestFlow",
|
||||
method_name="review",
|
||||
method_output=method_output,
|
||||
message="Approve?",
|
||||
emit=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
persistence.save_pending_feedback(
|
||||
flow_uuid="terminal-route-test-123",
|
||||
context=context,
|
||||
state_data={"id": "terminal-route-test-123"},
|
||||
)
|
||||
|
||||
flow = TestFlow.from_pending("terminal-route-test-123", persistence)
|
||||
|
||||
with patch.object(flow, "_collapse_to_outcome", return_value="approved"):
|
||||
result = flow.resume("yes, this looks great")
|
||||
|
||||
assert result == method_output
|
||||
assert flow.method_outputs == [method_output]
|
||||
assert flow.last_human_feedback.outcome == "approved"
|
||||
|
||||
@patch("crewai.flow.runtime.crewai_event_bus.emit")
|
||||
def test_resume_records_method_output_before_downstream_listeners(
|
||||
self, mock_emit: MagicMock
|
||||
) -> None:
|
||||
"""Downstream listeners can read outputs from the resumed method."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class TestFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(message="Review:")
|
||||
def review(self):
|
||||
return "generated content"
|
||||
|
||||
@listen(review)
|
||||
def downstream(self, result):
|
||||
self.state["seen_outputs"] = self.method_outputs
|
||||
return f"downstream:{result.output}"
|
||||
|
||||
context = PendingFeedbackContext(
|
||||
flow_id="listener-output-test-123",
|
||||
flow_class="test.TestFlow",
|
||||
method_name="review",
|
||||
method_output="generated content",
|
||||
message="Review:",
|
||||
)
|
||||
persistence.save_pending_feedback(
|
||||
flow_uuid="listener-output-test-123",
|
||||
context=context,
|
||||
state_data={"id": "listener-output-test-123"},
|
||||
)
|
||||
|
||||
flow = TestFlow.from_pending("listener-output-test-123", persistence)
|
||||
result = flow.resume("looks good")
|
||||
|
||||
assert result == "downstream:generated content"
|
||||
assert len(flow.state["seen_outputs"]) == 1
|
||||
seen_output = flow.state["seen_outputs"][0]
|
||||
assert isinstance(seen_output, HumanFeedbackResult)
|
||||
assert seen_output.output == "generated content"
|
||||
assert seen_output.feedback == "looks good"
|
||||
|
||||
|
||||
# Integration Tests with @human_feedback decorator
|
||||
|
||||
@@ -1168,132 +1294,13 @@ class TestAsyncHumanFeedbackEdgeCases:
|
||||
|
||||
|
||||
|
||||
class TestLiveLLMPreservationOnResume:
|
||||
"""Tests for preserving the full LLM config across HITL resume."""
|
||||
|
||||
def test_human_feedback_llm_attribute_set_on_wrapper_with_basellm(self) -> None:
|
||||
"""Test that _human_feedback_llm is set on the wrapper when llm is a BaseLLM instance."""
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
mock_llm = MagicMock(spec=BaseLLM)
|
||||
mock_llm.model = "gemini/gemini-3-flash"
|
||||
|
||||
class TestFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Review:",
|
||||
emit=["approved", "rejected"],
|
||||
llm=mock_llm,
|
||||
)
|
||||
def review(self):
|
||||
return "content"
|
||||
|
||||
flow = TestFlow()
|
||||
method = flow._methods.get("review")
|
||||
assert method is not None
|
||||
assert hasattr(method, "_human_feedback_llm")
|
||||
assert method._human_feedback_llm is mock_llm
|
||||
|
||||
def test_human_feedback_llm_attribute_set_on_wrapper_with_string(self) -> None:
|
||||
"""Test that _human_feedback_llm is set on the wrapper even when llm is a string."""
|
||||
|
||||
class TestFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Review:",
|
||||
emit=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
def review(self):
|
||||
return "content"
|
||||
|
||||
flow = TestFlow()
|
||||
method = flow._methods.get("review")
|
||||
assert method is not None
|
||||
assert hasattr(method, "_human_feedback_llm")
|
||||
assert method._human_feedback_llm == "gpt-4o-mini"
|
||||
class TestResumeLLMFromSerializedContext:
|
||||
"""Resume rebuilds the collapse LLM from the serialized context alone."""
|
||||
|
||||
@patch("crewai.flow.runtime.crewai_event_bus.emit")
|
||||
def test_resume_async_uses_live_basellm_over_serialized_string(
|
||||
def test_resume_builds_llm_from_serialized_context(
|
||||
self, mock_emit: MagicMock
|
||||
) -> None:
|
||||
"""Test that resume_async uses the live BaseLLM from decorator instead of serialized string.
|
||||
|
||||
This is the main bug fix: when a flow resumes, it should use the fully-configured
|
||||
LLM from the re-imported decorator (with credentials, project, etc.) instead of
|
||||
creating a new LLM from just the model string.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
# Create a mock BaseLLM with full config (simulating Gemini with service account)
|
||||
live_llm = MagicMock(spec=BaseLLM)
|
||||
live_llm.model = "gemini/gemini-3-flash"
|
||||
|
||||
class TestFlow(Flow):
|
||||
result_path: str = ""
|
||||
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Approve?",
|
||||
emit=["approved", "rejected"],
|
||||
llm=live_llm,
|
||||
)
|
||||
def review(self):
|
||||
return "content"
|
||||
|
||||
@listen("approved")
|
||||
def handle_approved(self):
|
||||
self.result_path = "approved"
|
||||
return "Approved!"
|
||||
|
||||
context = PendingFeedbackContext(
|
||||
flow_id="live-llm-test",
|
||||
flow_class="TestFlow",
|
||||
method_name="review",
|
||||
method_output="content",
|
||||
message="Approve?",
|
||||
emit=["approved", "rejected"],
|
||||
llm="gemini/gemini-3-flash", # Serialized string, NOT the live object
|
||||
)
|
||||
persistence.save_pending_feedback(
|
||||
flow_uuid="live-llm-test",
|
||||
context=context,
|
||||
state_data={"id": "live-llm-test"},
|
||||
)
|
||||
|
||||
flow = TestFlow.from_pending("live-llm-test", persistence)
|
||||
|
||||
captured_llm = []
|
||||
|
||||
def capture_llm(feedback, outcomes, llm):
|
||||
captured_llm.append(llm)
|
||||
return "approved"
|
||||
|
||||
with patch.object(flow, "_collapse_to_outcome", side_effect=capture_llm):
|
||||
flow.resume("looks good!")
|
||||
|
||||
# NOT the serialized string. The live_llm was captured at class definition
|
||||
# time and stored on the method wrapper as _human_feedback_llm.
|
||||
assert len(captured_llm) == 1
|
||||
# (which is stored on the method's _human_feedback_llm attribute)
|
||||
method = flow._methods.get("review")
|
||||
assert method is not None
|
||||
assert captured_llm[0] is method._human_feedback_llm
|
||||
# And verify it's a BaseLLM instance, not a string
|
||||
assert isinstance(captured_llm[0], BaseLLM)
|
||||
|
||||
@patch("crewai.flow.runtime.crewai_event_bus.emit")
|
||||
def test_resume_async_falls_back_to_serialized_string_when_no_human_feedback_llm(
|
||||
self, mock_emit: MagicMock
|
||||
) -> None:
|
||||
"""Test that resume_async falls back to context.llm when _human_feedback_llm is not available.
|
||||
|
||||
This ensures backward compatibility with flows that were paused before this fix.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
@@ -1325,11 +1332,6 @@ class TestLiveLLMPreservationOnResume:
|
||||
|
||||
flow = TestFlow.from_pending("fallback-test", persistence)
|
||||
|
||||
# Remove _human_feedback_llm to simulate old decorator without this attribute
|
||||
method = flow._methods.get("review")
|
||||
if hasattr(method, "_human_feedback_llm"):
|
||||
delattr(method, "_human_feedback_llm")
|
||||
|
||||
captured_llm = []
|
||||
|
||||
def capture_llm(feedback, outcomes, llm):
|
||||
@@ -1343,85 +1345,3 @@ class TestLiveLLMPreservationOnResume:
|
||||
from crewai.llms.base_llm import BaseLLM as BaseLLMClass
|
||||
assert isinstance(captured_llm[0], BaseLLMClass)
|
||||
assert captured_llm[0].model == "gpt-4o-mini"
|
||||
|
||||
@patch("crewai.flow.runtime.crewai_event_bus.emit")
|
||||
def test_resume_async_uses_string_from_context_when_human_feedback_llm_is_string(
|
||||
self, mock_emit: MagicMock
|
||||
) -> None:
|
||||
"""Test that when _human_feedback_llm is a string (not BaseLLM), we still use context.llm.
|
||||
|
||||
String LLM values offer no benefit over the serialized context.llm,
|
||||
so we don't prefer them.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class TestFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Approve?",
|
||||
emit=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
def review(self):
|
||||
return "content"
|
||||
|
||||
context = PendingFeedbackContext(
|
||||
flow_id="string-llm-test",
|
||||
flow_class="TestFlow",
|
||||
method_name="review",
|
||||
method_output="content",
|
||||
message="Approve?",
|
||||
emit=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
persistence.save_pending_feedback(
|
||||
flow_uuid="string-llm-test",
|
||||
context=context,
|
||||
state_data={"id": "string-llm-test"},
|
||||
)
|
||||
|
||||
flow = TestFlow.from_pending("string-llm-test", persistence)
|
||||
|
||||
method = flow._methods.get("review")
|
||||
assert method._human_feedback_llm == "gpt-4o-mini"
|
||||
|
||||
captured_llm = []
|
||||
|
||||
def capture_llm(feedback, outcomes, llm):
|
||||
captured_llm.append(llm)
|
||||
return "approved"
|
||||
|
||||
with patch.object(flow, "_collapse_to_outcome", side_effect=capture_llm):
|
||||
flow.resume("looks good!")
|
||||
|
||||
# _human_feedback_llm is a string, so resume deserializes context.llm into an LLM instance
|
||||
assert len(captured_llm) == 1
|
||||
from crewai.llms.base_llm import BaseLLM as BaseLLMClass
|
||||
assert isinstance(captured_llm[0], BaseLLMClass)
|
||||
assert captured_llm[0].model == "gpt-4o-mini"
|
||||
|
||||
def test_human_feedback_llm_set_for_async_wrapper(self) -> None:
|
||||
"""Test that _human_feedback_llm is set on async wrapper functions."""
|
||||
import asyncio
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
mock_llm = MagicMock(spec=BaseLLM)
|
||||
mock_llm.model = "gemini/gemini-3-flash"
|
||||
|
||||
class TestFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Review:",
|
||||
emit=["approved", "rejected"],
|
||||
llm=mock_llm,
|
||||
)
|
||||
async def async_review(self):
|
||||
return "content"
|
||||
|
||||
flow = TestFlow()
|
||||
method = flow._methods.get("async_review")
|
||||
assert method is not None
|
||||
assert hasattr(method, "_human_feedback_llm")
|
||||
assert method._human_feedback_llm is mock_llm
|
||||
|
||||
@@ -617,6 +617,44 @@ class TestKickoffFromCheckpoint:
|
||||
|
||||
|
||||
|
||||
class TestLegacyMethodOutputsRestore:
|
||||
def test_restore_wraps_legacy_plain_value_outputs(self) -> None:
|
||||
flow = Flow()
|
||||
flow._method_outputs = ["first", "second"]
|
||||
state = RuntimeState(root=[flow])
|
||||
state._provider = JsonProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
loc = state.checkpoint(d)
|
||||
cfg = CheckpointConfig(restore_from=loc)
|
||||
restored = Flow.from_checkpoint(cfg)
|
||||
|
||||
assert restored.method_outputs == ["first", "second"]
|
||||
|
||||
def test_restore_legacy_outputs_evaluates_expressions(self) -> None:
|
||||
from crewai.flow.runtime._expressions import _expression_context
|
||||
|
||||
flow = Flow()
|
||||
flow._method_outputs = ["legacy"]
|
||||
state = RuntimeState(root=[flow])
|
||||
state._provider = JsonProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
loc = state.checkpoint(d)
|
||||
cfg = CheckpointConfig(restore_from=loc)
|
||||
restored = Flow.from_checkpoint(cfg)
|
||||
|
||||
context = _expression_context(restored)
|
||||
assert context["outputs"] == {"": "legacy"}
|
||||
|
||||
def test_raw_legacy_outputs_remain_readable(self) -> None:
|
||||
from crewai.flow.runtime._expressions import _expression_context
|
||||
|
||||
flow = Flow()
|
||||
flow._method_outputs = ["legacy"]
|
||||
|
||||
assert flow.method_outputs == ["legacy"]
|
||||
assert _expression_context(flow)["outputs"] == {"": "legacy"}
|
||||
|
||||
|
||||
class TestAgentCheckpoint:
|
||||
def _make_agent_state(self) -> RuntimeState:
|
||||
agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
|
||||
|
||||
@@ -1157,6 +1157,26 @@ def test_flow_name():
|
||||
assert flow.name == "MyFlow"
|
||||
|
||||
|
||||
def test_flow_custom_name_overrides_class_name_in_events():
|
||||
class InternalFlowClass(Flow):
|
||||
name = "PublicName"
|
||||
|
||||
@start()
|
||||
def begin(self):
|
||||
return "done"
|
||||
|
||||
received = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def handle(source, event):
|
||||
received.append(event)
|
||||
|
||||
InternalFlowClass().kickoff()
|
||||
|
||||
assert received[0].flow_name == "PublicName"
|
||||
|
||||
|
||||
def test_nested_and_or_conditions():
|
||||
"""Test nested conditions like or_(and_(A, B), and_(C, D)).
|
||||
|
||||
|
||||
@@ -36,16 +36,22 @@ def test_flow_public_exports_are_explicit():
|
||||
"start",
|
||||
}
|
||||
assert set(flow_definition.__all__) == {
|
||||
"FlowActionDefinition",
|
||||
"FlowCodeActionDefinition",
|
||||
"FlowConfigDefinition",
|
||||
"FlowConversationalDefinition",
|
||||
"FlowConversationalRouterDefinition",
|
||||
"FlowDefinition",
|
||||
"FlowDefinitionCondition",
|
||||
"FlowDefinitionDiagnostic",
|
||||
"FlowEachActionDefinition",
|
||||
"FlowEachInnerActionDefinition",
|
||||
"FlowExpressionActionDefinition",
|
||||
"FlowHumanFeedbackDefinition",
|
||||
"FlowMethodDefinition",
|
||||
"FlowPersistenceDefinition",
|
||||
"FlowStateDefinition",
|
||||
"FlowToolActionDefinition",
|
||||
}
|
||||
assert "build_flow_structure" in flow_visualization.__all__
|
||||
assert "calculate_node_levels" not in flow_visualization.__all__
|
||||
@@ -428,6 +434,73 @@ def test_flow_definition_round_trips_json_and_yaml():
|
||||
assert yaml_round_trip.methods["decide"].listen == "begin"
|
||||
|
||||
|
||||
def test_each_action_round_trips_json_and_yaml():
|
||||
definition = flow_definition.FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "EachFlow",
|
||||
"methods": {
|
||||
"process_rows": {
|
||||
"description": "Process every loaded row.",
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "each",
|
||||
"in": "state.rows",
|
||||
"do": [
|
||||
{
|
||||
"normalize": {
|
||||
"call": "tool",
|
||||
"ref": "my_tools:NormalizeRowTool",
|
||||
"with": {"row": "${ item }"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"save": {
|
||||
"call": "code",
|
||||
"ref": "my_flow:save_row",
|
||||
"with": {
|
||||
"row": "${ item }",
|
||||
"normalized": "${ outputs.normalize }",
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
json_round_trip = flow_definition.FlowDefinition.from_json(definition.to_json())
|
||||
yaml_round_trip = flow_definition.FlowDefinition.from_yaml(definition.to_yaml())
|
||||
|
||||
assert json_round_trip.to_dict() == definition.to_dict()
|
||||
assert yaml_round_trip.to_dict() == definition.to_dict()
|
||||
assert yaml_round_trip.methods["process_rows"].description == (
|
||||
"Process every loaded row."
|
||||
)
|
||||
assert yaml_round_trip.methods["process_rows"].do.call == "each"
|
||||
|
||||
|
||||
def test_flow_definition_rejects_invalid_method_names():
|
||||
with pytest.raises(ValueError, match="Flow method names must match"):
|
||||
flow_definition.FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "InvalidMethodNameFlow",
|
||||
"methods": {
|
||||
"process-rows": {
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "expression",
|
||||
"expr": "'done'",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_flow_definition_detects_persist_metadata():
|
||||
@persist(verbose=True)
|
||||
class PersistedFlow(Flow[dict]):
|
||||
@@ -629,6 +702,7 @@ def test_flow_definition_preserves_diagnostics_loaded_from_contract():
|
||||
"name": "LoadedDiagnosticsFlow",
|
||||
"methods": {
|
||||
"decision": {
|
||||
"do": {"ref": "loaded_flows:LoadedDiagnosticsFlow.decision"},
|
||||
"router": True,
|
||||
"emit": ["continue"],
|
||||
}
|
||||
@@ -662,6 +736,7 @@ def test_router_start_false_without_listen_reports_missing_trigger():
|
||||
"name": "LoadedFlow",
|
||||
"methods": {
|
||||
"decision": {
|
||||
"do": {"ref": "loaded_flows:LoadedFlow.decision"},
|
||||
"router": True,
|
||||
"start": False,
|
||||
"emit": ["continue"],
|
||||
@@ -740,8 +815,14 @@ def test_static_string_listener_is_allowed_by_contract():
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "TypoFlow",
|
||||
"methods": {
|
||||
"begin": {"start": True},
|
||||
"handle": {"listen": "begni"},
|
||||
"begin": {
|
||||
"do": {"ref": "loaded_flows:TypoFlow.begin"},
|
||||
"start": True,
|
||||
},
|
||||
"handle": {
|
||||
"do": {"ref": "loaded_flows:TypoFlow.handle"},
|
||||
"listen": "begni",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -754,8 +835,15 @@ def test_start_false_not_classified_as_start_method():
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "ExplicitNonStartFlow",
|
||||
"methods": {
|
||||
"begin": {"start": True},
|
||||
"handle": {"start": False, "listen": "begin"},
|
||||
"begin": {
|
||||
"do": {"ref": "loaded_flows:ExplicitNonStartFlow.begin"},
|
||||
"start": True,
|
||||
},
|
||||
"handle": {
|
||||
"do": {"ref": "loaded_flows:ExplicitNonStartFlow.handle"},
|
||||
"start": False,
|
||||
"listen": "begin",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -812,6 +900,7 @@ def test_flow_definition_logs_diagnostics_when_loaded_from_contract(caplog):
|
||||
"name": "LoadedFlow",
|
||||
"methods": {
|
||||
"decision": {
|
||||
"do": {"ref": "loaded_flows:LoadedFlow.decision"},
|
||||
"router": True,
|
||||
"emit": ["continue"],
|
||||
}
|
||||
|
||||
2057
lib/crewai/tests/test_flow_from_definition.py
Normal file
2057
lib/crewai/tests/test_flow_from_definition.py
Normal file
File diff suppressed because it is too large
Load Diff
511
lib/crewai/tests/test_flow_usage_metrics.py
Normal file
511
lib/crewai/tests/test_flow_usage_metrics.py
Normal file
@@ -0,0 +1,511 @@
|
||||
"""Tests for flow-level token usage aggregation
|
||||
|
||||
``flow.usage_metrics`` listens to ``LLMCallCompletedEvent`` for the duration
|
||||
of ``kickoff_async`` so it covers every LLM call inside the flow — crew-led,
|
||||
tool-led, AND bare ``LLM.call(...)`` from a flow method. We exercise the
|
||||
aggregator end-to-end through the real event bus with fabricated events and
|
||||
explicit contextvar control; no live LLM provider is required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Callable
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallType
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
from crewai.flow.flow import Flow, listen, start
|
||||
from crewai.flow.flow_context import current_flow_id
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
from crewai.flow.runtime import _usage_dict_to_metrics
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
|
||||
|
||||
def _emit_llm_call(
|
||||
*,
|
||||
flow_id: str | None,
|
||||
prompt_tokens: int = 0,
|
||||
completion_tokens: int = 0,
|
||||
cached_prompt_tokens: int = 0,
|
||||
reasoning_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
) -> None:
|
||||
"""Emit one fake ``LLMCallCompletedEvent`` with ``current_flow_id`` pinned
|
||||
to ``flow_id``.
|
||||
|
||||
Runs in a freshly-copied context so the value the bus snapshots at emit
|
||||
time is exactly ``flow_id`` — independent of the calling thread's outer
|
||||
context. Mirrors how the real ``LLM.call`` emits events at runtime.
|
||||
"""
|
||||
usage: dict[str, Any] = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
}
|
||||
for key, value in (
|
||||
("cached_prompt_tokens", cached_prompt_tokens),
|
||||
("reasoning_tokens", reasoning_tokens),
|
||||
("cache_creation_tokens", cache_creation_tokens),
|
||||
):
|
||||
if value:
|
||||
usage[key] = value
|
||||
event = LLMCallCompletedEvent(
|
||||
call_id=str(uuid4()),
|
||||
model="gpt-4o-mini",
|
||||
response="ok",
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
|
||||
def _emit() -> None:
|
||||
current_flow_id.set(flow_id)
|
||||
future = crewai_event_bus.emit(object(), event)
|
||||
if future is not None:
|
||||
future.result(timeout=5.0)
|
||||
|
||||
ctx.run(_emit)
|
||||
|
||||
|
||||
class _ScriptedFlow(Flow):
|
||||
"""A Flow whose ``@start`` delegates to a per-instance ``_script`` closure.
|
||||
|
||||
Each test attaches a script with ``flow._script = lambda f: ...`` so we
|
||||
don't redefine a Flow subclass for every scenario.
|
||||
"""
|
||||
|
||||
@start()
|
||||
def run(self) -> None:
|
||||
script: Callable[[Flow], None] = getattr(self, "_script", lambda _f: None)
|
||||
script(self)
|
||||
|
||||
|
||||
def _run(script: Callable[[Flow], None] = lambda _f: None) -> Flow:
|
||||
"""Build a ``_ScriptedFlow``, attach ``script``, kickoff. Returns the flow."""
|
||||
flow = _ScriptedFlow()
|
||||
flow._script = script
|
||||
flow.kickoff()
|
||||
return flow
|
||||
|
||||
|
||||
class TestUsageDictToMetrics:
|
||||
"""Unit tests for the dict-to-UsageMetrics normalizer."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"usage, expected",
|
||||
[
|
||||
(None, None),
|
||||
({}, None),
|
||||
(
|
||||
{"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
|
||||
UsageMetrics(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=20,
|
||||
total_tokens=30,
|
||||
successful_requests=1,
|
||||
),
|
||||
),
|
||||
# total_tokens missing → derived from prompt + completion
|
||||
(
|
||||
{"prompt_tokens": 4, "completion_tokens": 6},
|
||||
UsageMetrics(
|
||||
prompt_tokens=4,
|
||||
completion_tokens=6,
|
||||
total_tokens=10,
|
||||
successful_requests=1,
|
||||
),
|
||||
),
|
||||
# Extended provider-specific keys flow through normalization
|
||||
(
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 80,
|
||||
"total_tokens": 180,
|
||||
"cached_prompt_tokens": 40,
|
||||
"reasoning_tokens": 25,
|
||||
"cache_creation_tokens": 10,
|
||||
},
|
||||
UsageMetrics(
|
||||
prompt_tokens=100,
|
||||
completion_tokens=80,
|
||||
total_tokens=180,
|
||||
cached_prompt_tokens=40,
|
||||
reasoning_tokens=25,
|
||||
cache_creation_tokens=10,
|
||||
successful_requests=1,
|
||||
),
|
||||
),
|
||||
# Garbage / non-int values coerce to 0 instead of crashing
|
||||
(
|
||||
{"prompt_tokens": "n/a", "completion_tokens": None, "total_tokens": 7},
|
||||
UsageMetrics(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
successful_requests=1,
|
||||
),
|
||||
),
|
||||
# Native Anthropic provider emits input_tokens/output_tokens
|
||||
(
|
||||
{"input_tokens": 12, "output_tokens": 8},
|
||||
UsageMetrics(
|
||||
prompt_tokens=12,
|
||||
completion_tokens=8,
|
||||
total_tokens=20,
|
||||
successful_requests=1,
|
||||
),
|
||||
),
|
||||
# Native Gemini provider emits prompt_token_count/candidates_token_count
|
||||
(
|
||||
{
|
||||
"prompt_token_count": 30,
|
||||
"candidates_token_count": 20,
|
||||
"reasoning_tokens": 5,
|
||||
},
|
||||
UsageMetrics(
|
||||
prompt_tokens=30,
|
||||
completion_tokens=20,
|
||||
total_tokens=50,
|
||||
reasoning_tokens=5,
|
||||
successful_requests=1,
|
||||
),
|
||||
),
|
||||
# OpenAI nests cached_tokens under prompt_tokens_details
|
||||
(
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"prompt_tokens_details": {"cached_tokens": 30},
|
||||
},
|
||||
UsageMetrics(
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
total_tokens=150,
|
||||
cached_prompt_tokens=30,
|
||||
successful_requests=1,
|
||||
),
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"none",
|
||||
"empty",
|
||||
"all_keys",
|
||||
"no_total",
|
||||
"extended_keys",
|
||||
"garbage",
|
||||
"anthropic_aliases",
|
||||
"gemini_aliases",
|
||||
"openai_nested_cached",
|
||||
],
|
||||
)
|
||||
def test_normalization(
|
||||
self, usage: dict[str, Any] | None, expected: UsageMetrics | None
|
||||
) -> None:
|
||||
assert _usage_dict_to_metrics(usage) == expected
|
||||
|
||||
|
||||
class TestFlowUsageAggregation:
|
||||
"""End-to-end tests driving the listener through the real event bus."""
|
||||
|
||||
def test_sums_every_llm_call_in_the_flow(self) -> None:
|
||||
"""Multiple LLM calls — including bare ``LLM.call(...)`` made outside
|
||||
any crew — accumulate; ``successful_requests`` tracks the call count."""
|
||||
|
||||
def script(flow: Flow) -> None:
|
||||
_emit_llm_call(flow_id=flow._flow_match_id, prompt_tokens=300, completion_tokens=300)
|
||||
_emit_llm_call(flow_id=flow._flow_match_id, prompt_tokens=200, completion_tokens=100)
|
||||
_emit_llm_call(flow_id=flow._flow_match_id, prompt_tokens=20, completion_tokens=20)
|
||||
|
||||
flow = _run(script)
|
||||
|
||||
assert flow.usage_metrics.total_tokens == 940
|
||||
assert flow.usage_metrics.prompt_tokens == 520
|
||||
assert flow.usage_metrics.completion_tokens == 420
|
||||
assert flow.usage_metrics.successful_requests == 3
|
||||
|
||||
def test_returns_zero_when_no_calls_happen(self) -> None:
|
||||
flow = _run()
|
||||
assert flow.usage_metrics == UsageMetrics()
|
||||
|
||||
def test_ignores_events_from_other_flows(self) -> None:
|
||||
"""Concurrent flow runs share the singleton bus, so the listener must
|
||||
scope itself to its own flow via the contextvar match."""
|
||||
|
||||
def script(flow: Flow) -> None:
|
||||
_emit_llm_call(flow_id=flow._flow_match_id, prompt_tokens=50, completion_tokens=50)
|
||||
_emit_llm_call(flow_id="some-other-flow", prompt_tokens=49_000, completion_tokens=50_999)
|
||||
|
||||
flow = _run(script)
|
||||
|
||||
assert flow.usage_metrics.total_tokens == 100
|
||||
assert flow.usage_metrics.successful_requests == 1
|
||||
|
||||
def test_resets_between_kickoffs(self) -> None:
|
||||
flow = _ScriptedFlow()
|
||||
flow._script = lambda f: _emit_llm_call(
|
||||
flow_id=f._flow_match_id, prompt_tokens=250, completion_tokens=250
|
||||
)
|
||||
|
||||
flow.kickoff()
|
||||
flow.kickoff()
|
||||
|
||||
assert flow.usage_metrics.total_tokens == 500
|
||||
assert flow.usage_metrics.successful_requests == 1
|
||||
|
||||
def test_usage_metrics_returns_independent_copy(self) -> None:
|
||||
"""``usage_metrics`` must return a copy, not the internal instance —
|
||||
otherwise callers can clobber the in-flight accumulator."""
|
||||
|
||||
flow = _run(
|
||||
lambda f: _emit_llm_call(
|
||||
flow_id=f._flow_match_id, prompt_tokens=50, completion_tokens=50
|
||||
)
|
||||
)
|
||||
|
||||
snapshot = flow.usage_metrics
|
||||
snapshot.total_tokens = 999_999
|
||||
|
||||
assert flow.usage_metrics.total_tokens == 100
|
||||
|
||||
def test_handler_is_unregistered_after_kickoff(self) -> None:
|
||||
"""Long-lived workers (Celery, devkit) must not leak one handler per
|
||||
kickoff on the singleton bus, on either the success or failure path."""
|
||||
|
||||
def handler_count() -> int:
|
||||
return len(
|
||||
crewai_event_bus._sync_handlers.get(LLMCallCompletedEvent, frozenset())
|
||||
)
|
||||
|
||||
before = handler_count()
|
||||
|
||||
flow = _ScriptedFlow()
|
||||
flow._script = lambda f: _emit_llm_call(
|
||||
flow_id=f._flow_match_id, prompt_tokens=5, completion_tokens=5
|
||||
)
|
||||
for _ in range(3):
|
||||
flow.kickoff()
|
||||
|
||||
assert handler_count() == before
|
||||
|
||||
def boom(_f: Flow) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
failing = _ScriptedFlow()
|
||||
failing._script = boom
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
failing.kickoff()
|
||||
|
||||
assert handler_count() == before
|
||||
|
||||
def test_kickoff_flushes_event_bus_before_returning(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""`kickoff_async` must drain pending LLMCallCompletedEvent handlers
|
||||
before detaching the listener — otherwise late handlers landing on
|
||||
the threadpool would be lost on short flows. Mirrors the flush
|
||||
``Crew.kickoff()`` performs before reporting ``token_usage``."""
|
||||
|
||||
flush_calls: list[None] = []
|
||||
original_flush = crewai_event_bus.flush
|
||||
|
||||
def tracked_flush(*args: Any, **kwargs: Any) -> bool:
|
||||
flush_calls.append(None)
|
||||
return original_flush(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(crewai_event_bus, "flush", tracked_flush)
|
||||
|
||||
flow = _ScriptedFlow()
|
||||
flow._script = lambda f: _emit_llm_call(
|
||||
flow_id=f._flow_match_id, prompt_tokens=3, completion_tokens=4
|
||||
)
|
||||
flow.kickoff()
|
||||
|
||||
assert flush_calls, "kickoff did not flush the event bus before returning"
|
||||
assert flow.usage_metrics.total_tokens == 7
|
||||
|
||||
def test_stale_handler_from_prior_kickoff_does_not_contaminate(self) -> None:
|
||||
"""A handler still queued from a prior kickoff must not write into
|
||||
a later kickoff's accumulator. The handler's closure captures its
|
||||
own accumulator object, so any late writes land on an orphaned
|
||||
instance and the live ``usage_metrics`` is unaffected."""
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def script(flow: Flow) -> None:
|
||||
_emit_llm_call(flow_id=flow._flow_match_id, prompt_tokens=10, completion_tokens=10)
|
||||
captured["handler"] = flow._usage_aggregation_handler
|
||||
captured["match_id"] = flow._flow_match_id
|
||||
|
||||
flow = _run(script)
|
||||
assert flow.usage_metrics.total_tokens == 20
|
||||
|
||||
flow._script = lambda f: None
|
||||
flow.kickoff()
|
||||
assert flow.usage_metrics.total_tokens == 0
|
||||
|
||||
stale_handler = captured["handler"]
|
||||
assert stale_handler is not None
|
||||
|
||||
stale_event = LLMCallCompletedEvent(
|
||||
call_id=str(uuid4()),
|
||||
model="gpt-4o-mini",
|
||||
response="ok",
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
usage={"prompt_tokens": 999, "completion_tokens": 999, "total_tokens": 1998},
|
||||
)
|
||||
ctx = contextvars.copy_context()
|
||||
ctx.run(lambda: (current_flow_id.set(captured["match_id"]), stale_handler(object(), stale_event)))
|
||||
|
||||
assert flow.usage_metrics.total_tokens == 0
|
||||
|
||||
def test_pause_detaches_listener_and_does_not_leak(self) -> None:
|
||||
"""When ``kickoff_async`` pauses for human feedback, the listener
|
||||
must be detached from the singleton bus to avoid leaking handlers
|
||||
across abandoned paused instances. Pre-pause LLM events still
|
||||
count because the bus snapshots handlers at emit time. Late
|
||||
events emitted after the pause returns do not count for this
|
||||
instance — resume paths re-attach a fresh listener."""
|
||||
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
class _PausingFlow(Flow):
|
||||
@start()
|
||||
def begin(self) -> None:
|
||||
_emit_llm_call(
|
||||
flow_id=self._flow_match_id,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=20,
|
||||
)
|
||||
captured["pre_pause_total"] = self.usage_metrics.total_tokens
|
||||
raise HumanFeedbackPending(
|
||||
context=PendingFeedbackContext(
|
||||
flow_id=self.flow_id,
|
||||
flow_class="_PausingFlow",
|
||||
method_name="begin",
|
||||
method_output="content",
|
||||
message="Review:",
|
||||
)
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
persistence = SQLiteFlowPersistence(os.path.join(tmpdir, "f.db"))
|
||||
flow = _PausingFlow(persistence=persistence)
|
||||
result = flow.kickoff()
|
||||
|
||||
assert isinstance(result, HumanFeedbackPending)
|
||||
assert captured["pre_pause_total"] == 30
|
||||
assert flow._usage_aggregation_handler is None
|
||||
|
||||
# A late event emitted after the pause does not reach the
|
||||
# detached listener, so the running total is unchanged.
|
||||
_emit_llm_call(
|
||||
flow_id=flow._flow_match_id,
|
||||
prompt_tokens=2,
|
||||
completion_tokens=3,
|
||||
)
|
||||
assert flow.usage_metrics.total_tokens == 30
|
||||
|
||||
def test_aggregates_resume_after_from_pending(self) -> None:
|
||||
"""A flow restored via ``from_pending`` is a fresh instance with no
|
||||
``_flow_match_id``; without seeding it, the listener attached in
|
||||
``resume_async`` either ignores its own LLM calls or absorbs unrelated
|
||||
ones. ``from_pending`` must seed the match id so the resume-phase
|
||||
aggregator counts our own calls and only our own calls."""
|
||||
|
||||
class _ResumeFlow(Flow):
|
||||
@start()
|
||||
def begin(self) -> str:
|
||||
return "content"
|
||||
|
||||
@listen(begin)
|
||||
def on_begin(self, _feedback: Any) -> str:
|
||||
_emit_llm_call(
|
||||
flow_id=self._flow_match_id,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
_emit_llm_call(
|
||||
flow_id="some-other-flow",
|
||||
prompt_tokens=9_999,
|
||||
completion_tokens=9_999,
|
||||
)
|
||||
return "done"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
persistence = SQLiteFlowPersistence(os.path.join(tmpdir, "f.db"))
|
||||
flow_id = "usage-resume-test"
|
||||
persistence.save_pending_feedback(
|
||||
flow_uuid=flow_id,
|
||||
context=PendingFeedbackContext(
|
||||
flow_id=flow_id,
|
||||
flow_class="_ResumeFlow",
|
||||
method_name="begin",
|
||||
method_output="content",
|
||||
message="Review:",
|
||||
),
|
||||
state_data={"id": flow_id},
|
||||
)
|
||||
|
||||
flow = _ResumeFlow.from_pending(flow_id, persistence)
|
||||
assert flow._flow_match_id == flow.flow_id
|
||||
|
||||
flow.resume("ok")
|
||||
|
||||
assert flow.usage_metrics.total_tokens == 150
|
||||
assert flow.usage_metrics.prompt_tokens == 100
|
||||
assert flow.usage_metrics.completion_tokens == 50
|
||||
assert flow.usage_metrics.successful_requests == 1
|
||||
|
||||
def test_resume_aggregates_under_foreign_flow_context(self) -> None:
|
||||
"""Resume must override an already-set ``current_flow_id`` so its
|
||||
own LLM events match the listener's filter even when invoked from
|
||||
inside another flow's active context."""
|
||||
|
||||
class _ResumeFlow(Flow):
|
||||
@start()
|
||||
def begin(self) -> str:
|
||||
return "content"
|
||||
|
||||
@listen(begin)
|
||||
def on_begin(self, _feedback: Any) -> str:
|
||||
_emit_llm_call(
|
||||
flow_id=self._flow_match_id,
|
||||
prompt_tokens=42,
|
||||
completion_tokens=8,
|
||||
)
|
||||
return "done"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
persistence = SQLiteFlowPersistence(os.path.join(tmpdir, "f.db"))
|
||||
flow_id = "resume-foreign-context"
|
||||
persistence.save_pending_feedback(
|
||||
flow_uuid=flow_id,
|
||||
context=PendingFeedbackContext(
|
||||
flow_id=flow_id,
|
||||
flow_class="_ResumeFlow",
|
||||
method_name="begin",
|
||||
method_output="content",
|
||||
message="Review:",
|
||||
),
|
||||
state_data={"id": flow_id},
|
||||
)
|
||||
|
||||
foreign_token = current_flow_id.set("some-parent-flow")
|
||||
try:
|
||||
flow = _ResumeFlow.from_pending(flow_id, persistence)
|
||||
flow.resume("ok")
|
||||
finally:
|
||||
current_flow_id.reset(foreign_token)
|
||||
|
||||
assert flow.usage_metrics.total_tokens == 50
|
||||
assert flow.usage_metrics.successful_requests == 1
|
||||
@@ -77,12 +77,22 @@ class ComplexFlow(Flow):
|
||||
return "complete"
|
||||
|
||||
|
||||
def _attach_flow_definition(flow_class: type[Flow], methods: dict[str, object]) -> None:
|
||||
def _attach_flow_definition(
|
||||
flow_class: type[Flow], methods: dict[str, dict[str, object]]
|
||||
) -> None:
|
||||
flow_class._flow_definition = FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": flow_class.__name__,
|
||||
"methods": methods,
|
||||
"methods": {
|
||||
name: {
|
||||
"do": {
|
||||
"ref": f"{flow_class.__module__}:{flow_class.__name__}.{name}"
|
||||
},
|
||||
**spec,
|
||||
}
|
||||
for name, spec in methods.items()
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -125,13 +135,20 @@ def test_build_flow_structure_from_flow_definition():
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "DefinedFlow",
|
||||
"methods": {
|
||||
"begin": {"start": True},
|
||||
"begin": {
|
||||
"do": {"ref": "defined_flows:DefinedFlow.begin"},
|
||||
"start": True,
|
||||
},
|
||||
"decide": {
|
||||
"do": {"ref": "defined_flows:DefinedFlow.decide"},
|
||||
"listen": "begin",
|
||||
"router": True,
|
||||
"emit": ["done"],
|
||||
},
|
||||
"finish": {"listen": "done"},
|
||||
"finish": {
|
||||
"do": {"ref": "defined_flows:DefinedFlow.finish"},
|
||||
"listen": "done",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -92,8 +92,8 @@ class TestHumanFeedbackValidation:
|
||||
assert hasattr(test_method, "__human_feedback_config__")
|
||||
assert not hasattr(test_method, "__is_router__")
|
||||
|
||||
def test_persist_preserves_human_feedback_llm_attribute(self):
|
||||
"""Test @persist preserves the live LLM stashed by @human_feedback."""
|
||||
def test_persist_preserves_human_feedback_config(self):
|
||||
"""Test @persist preserves the config stamped by @human_feedback."""
|
||||
llm = object()
|
||||
|
||||
@persist()
|
||||
@@ -105,8 +105,8 @@ class TestHumanFeedbackValidation:
|
||||
def test_method(self):
|
||||
return "output"
|
||||
|
||||
assert hasattr(test_method, "_human_feedback_llm")
|
||||
assert test_method._human_feedback_llm is llm
|
||||
assert hasattr(test_method, "__human_feedback_config__")
|
||||
assert test_method.__human_feedback_config__.llm is llm
|
||||
|
||||
|
||||
class TestHumanFeedbackConfig:
|
||||
@@ -481,7 +481,7 @@ class TestHumanFeedbackLearn:
|
||||
with patch.object(
|
||||
flow, "_request_human_feedback", return_value="looks good"
|
||||
):
|
||||
flow.produce()
|
||||
flow.kickoff()
|
||||
|
||||
# memory.recall and memory.remember_many should NOT be called
|
||||
flow.memory.recall.assert_not_called()
|
||||
@@ -516,7 +516,7 @@ class TestHumanFeedbackLearn:
|
||||
)
|
||||
MockLLM.return_value = mock_llm
|
||||
|
||||
flow.produce()
|
||||
flow.kickoff()
|
||||
|
||||
# remember_many should be called with the distilled lesson
|
||||
flow.memory.remember_many.assert_called_once()
|
||||
@@ -551,7 +551,7 @@ class TestHumanFeedbackLearn:
|
||||
|
||||
captured_output = {}
|
||||
|
||||
def capture_feedback(message, output, metadata=None, emit=None):
|
||||
def capture_feedback(message, output, metadata=None, emit=None, method_name=""):
|
||||
captured_output["shown_to_human"] = output
|
||||
return "approved"
|
||||
|
||||
@@ -570,7 +570,7 @@ class TestHumanFeedbackLearn:
|
||||
]
|
||||
MockLLM.return_value = mock_llm
|
||||
|
||||
flow.produce()
|
||||
flow.kickoff()
|
||||
|
||||
assert captured_output["shown_to_human"] == "draft with citations added"
|
||||
# recall was called to find past lessons
|
||||
@@ -592,7 +592,7 @@ class TestHumanFeedbackLearn:
|
||||
with patch.object(
|
||||
flow, "_request_human_feedback", return_value=""
|
||||
):
|
||||
flow.produce()
|
||||
flow.kickoff()
|
||||
|
||||
flow.memory.remember_many.assert_not_called()
|
||||
|
||||
@@ -631,7 +631,7 @@ class TestHumanFeedbackLearn:
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def capture_feedback(message, output, metadata=None, emit=None):
|
||||
def capture_feedback(message, output, metadata=None, emit=None, method_name=""):
|
||||
captured["shown_to_human"] = output
|
||||
return ""
|
||||
|
||||
@@ -645,7 +645,7 @@ class TestHumanFeedbackLearn:
|
||||
mock_llm.call.side_effect = RuntimeError("simulated pre-review failure")
|
||||
MockLLM.return_value = mock_llm
|
||||
|
||||
flow.produce()
|
||||
flow.kickoff()
|
||||
|
||||
assert captured["shown_to_human"] == "raw draft"
|
||||
assert any(
|
||||
@@ -690,7 +690,7 @@ class TestHumanFeedbackLearn:
|
||||
MockLLM.return_value = mock_llm
|
||||
|
||||
with pytest.raises(RuntimeError, match="simulated pre-review failure"):
|
||||
flow.produce()
|
||||
flow.kickoff()
|
||||
|
||||
def test_distillation_failure_logs_and_does_not_block_flow(self, caplog):
|
||||
"""Distillation LLM failure logs a warning but does not break the flow."""
|
||||
@@ -717,7 +717,7 @@ class TestHumanFeedbackLearn:
|
||||
mock_llm.call.side_effect = RuntimeError("simulated distill failure")
|
||||
MockLLM.return_value = mock_llm
|
||||
|
||||
flow.produce() # must not raise
|
||||
flow.kickoff() # must not raise
|
||||
|
||||
flow.memory.remember_many.assert_not_called()
|
||||
assert any(
|
||||
@@ -860,9 +860,9 @@ class TestHumanFeedbackFinalOutputPreservation:
|
||||
):
|
||||
flow.kickoff()
|
||||
|
||||
# _method_outputs should contain the real output
|
||||
assert len(flow._method_outputs) == 1
|
||||
assert flow._method_outputs[0] == {"data": "real output"}
|
||||
# method_outputs should contain the real output
|
||||
assert flow.method_outputs == [{"data": "real output"}]
|
||||
assert flow._method_outputs[0]["method"] == "generate"
|
||||
|
||||
@patch("builtins.input", return_value="looks good")
|
||||
@patch("builtins.print")
|
||||
|
||||
@@ -778,77 +778,11 @@ class TestEdgeCases:
|
||||
class TestLLMConfigPreservation:
|
||||
"""Tests that LLM config is preserved through @human_feedback serialization.
|
||||
|
||||
PR #4970 introduced _human_feedback_llm stashing so the live LLM object survives
|
||||
decorator wrapping for same-process resume. The serialization path
|
||||
(_serialize_llm_for_context / _deserialize_llm_from_context) preserves
|
||||
config for cross-process resume.
|
||||
The flow definition keeps the live LLM object for same-process execution.
|
||||
The serialization path (_serialize_llm_for_context /
|
||||
_deserialize_llm_from_context) preserves config for cross-process resume.
|
||||
"""
|
||||
|
||||
def test_human_feedback_llm_stashed_on_wrapper_with_llm_instance(self):
|
||||
"""Test that passing an LLM instance stashes it on the wrapper as _human_feedback_llm."""
|
||||
from crewai.llm import LLM
|
||||
|
||||
llm_instance = LLM(model="gpt-4o-mini", temperature=0.42)
|
||||
|
||||
class ConfigFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Review:",
|
||||
emit=["approved", "rejected"],
|
||||
llm=llm_instance,
|
||||
)
|
||||
def review(self):
|
||||
return "content"
|
||||
|
||||
method = ConfigFlow.review
|
||||
assert hasattr(method, "_human_feedback_llm"), "_human_feedback_llm not found on wrapper"
|
||||
assert method._human_feedback_llm is llm_instance, "_human_feedback_llm is not the same object"
|
||||
|
||||
def test_human_feedback_llm_preserved_on_listen_method(self):
|
||||
"""Test that _human_feedback_llm is preserved when @human_feedback is on a @listen method."""
|
||||
from crewai.llm import LLM
|
||||
|
||||
llm_instance = LLM(model="gpt-4o-mini", temperature=0.7)
|
||||
|
||||
class ListenConfigFlow(Flow):
|
||||
@start()
|
||||
def generate(self):
|
||||
return "draft"
|
||||
|
||||
@listen("generate")
|
||||
@human_feedback(
|
||||
message="Review:",
|
||||
emit=["approved", "rejected"],
|
||||
llm=llm_instance,
|
||||
)
|
||||
def review(self):
|
||||
return "content"
|
||||
|
||||
method = ListenConfigFlow.review
|
||||
assert hasattr(method, "_human_feedback_llm")
|
||||
assert method._human_feedback_llm is llm_instance
|
||||
|
||||
def test_human_feedback_llm_accessible_on_instance(self):
|
||||
"""Test that _human_feedback_llm survives Flow instantiation (bound method access)."""
|
||||
from crewai.llm import LLM
|
||||
|
||||
llm_instance = LLM(model="gpt-4o-mini", temperature=0.42)
|
||||
|
||||
class InstanceFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Review:",
|
||||
emit=["approved", "rejected"],
|
||||
llm=llm_instance,
|
||||
)
|
||||
def review(self):
|
||||
return "content"
|
||||
|
||||
flow = InstanceFlow()
|
||||
instance_method = flow.review
|
||||
assert hasattr(instance_method, "_human_feedback_llm")
|
||||
assert instance_method._human_feedback_llm is llm_instance
|
||||
|
||||
def test_serialize_llm_preserves_config_fields(self):
|
||||
"""Test that _serialize_llm_for_context captures temperature, base_url, etc."""
|
||||
from crewai.flow.human_feedback import _serialize_llm_for_context
|
||||
|
||||
Reference in New Issue
Block a user