mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-11 14:28:14 +00:00
Compare commits
13 Commits
main
...
gl/refacto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e869cdc4e5 | ||
|
|
84deb5cc62 | ||
|
|
7d9faa7cbf | ||
|
|
e3ab996893 | ||
|
|
62f3279bc5 | ||
|
|
9682e458d6 | ||
|
|
6df5421785 | ||
|
|
f5116004db | ||
|
|
31b8a0989a | ||
|
|
3e4226268c | ||
|
|
a10ef6e28d | ||
|
|
3dc3f8bb52 | ||
|
|
441e214a00 |
@@ -105,6 +105,9 @@ a2a = [
|
||||
file-processing = [
|
||||
"crewai-files",
|
||||
]
|
||||
pickling = [
|
||||
'cloudpickle~=3.1.2'
|
||||
]
|
||||
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -35,6 +35,7 @@ from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai_files import FileInput
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
try:
|
||||
from crewai_files import get_supported_content_types
|
||||
@@ -65,8 +66,10 @@ from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
has_user_declined_tracing,
|
||||
set_tracing_enabled,
|
||||
should_enable_tracing,
|
||||
should_suppress_tracing_messages,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
@@ -83,7 +86,10 @@ from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.process import Process
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.rag.types import SearchResult
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
@@ -94,6 +100,8 @@ from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.agent_tools.read_file_tool import ReadFileTool
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.memory_tools import create_memory_tools
|
||||
from crewai.types.callable import SerializableCallable
|
||||
from crewai.types.streaming import CrewStreamingOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
@@ -165,12 +173,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__
|
||||
_execution_span: Any = PrivateAttr()
|
||||
_execution_span: Span | None = PrivateAttr(default=None)
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
_file_handler: FileHandler = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler)
|
||||
_memory: Any = PrivateAttr(default=None) # Unified Memory | MemoryScope
|
||||
_memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None)
|
||||
_train: bool | None = PrivateAttr(default=False)
|
||||
_train_iteration: int | None = PrivateAttr()
|
||||
_inputs: dict[str, Any] | None = PrivateAttr(default=None)
|
||||
@@ -188,7 +196,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
agents: list[BaseAgent] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool | Any = Field(
|
||||
memory: bool | Memory | MemoryScope | MemorySlice = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Enable crew memory. Pass True for default Memory(), "
|
||||
@@ -203,23 +211,23 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
manager_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: BaseAgent | None = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
share_crew: bool | None = Field(default=False)
|
||||
step_callback: Any | None = Field(
|
||||
step_callback: SerializableCallable | None = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each step for all agents execution.",
|
||||
)
|
||||
task_callback: Any | None = Field(
|
||||
task_callback: SerializableCallable | None = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each task for all agents execution.",
|
||||
)
|
||||
@@ -262,7 +270,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=False,
|
||||
description="Plan the crew execution and add the plan to the crew.",
|
||||
)
|
||||
planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
planning_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Language model that will run the AgentPlanner if planning is True."
|
||||
@@ -283,7 +291,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"knowledge object."
|
||||
),
|
||||
)
|
||||
chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
chat_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
default=None,
|
||||
description="LLM used to handle chatting with the crew.",
|
||||
)
|
||||
@@ -356,12 +364,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def create_crew_memory(self) -> Crew:
|
||||
"""Initialize unified memory, respecting crew embedder config."""
|
||||
if self.memory is True:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
embedder = None
|
||||
if self.embedder is not None:
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
embedder = build_embedder(self.embedder)
|
||||
self._memory = Memory(embedder=embedder)
|
||||
elif self.memory:
|
||||
@@ -1411,7 +1415,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return tools
|
||||
|
||||
def _add_memory_tools(
|
||||
self, tools: list[BaseTool], memory: Any
|
||||
self, tools: list[BaseTool], memory: Memory | MemoryScope | MemorySlice
|
||||
) -> list[BaseTool]:
|
||||
"""Add recall and remember tools when memory is available.
|
||||
|
||||
@@ -1422,8 +1426,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
Returns:
|
||||
Updated list with memory tools added.
|
||||
"""
|
||||
from crewai.tools.memory_tools import create_memory_tools
|
||||
|
||||
return self._merge_tools(tools, create_memory_tools(memory))
|
||||
|
||||
def _add_file_tools(
|
||||
@@ -2006,11 +2008,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
@staticmethod
|
||||
def _show_tracing_disabled_message() -> None:
|
||||
"""Show a message when tracing is disabled."""
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
has_user_declined_tracing,
|
||||
should_suppress_tracing_messages,
|
||||
)
|
||||
|
||||
if should_suppress_tracing_messages():
|
||||
return
|
||||
|
||||
|
||||
@@ -17,9 +17,12 @@ from collections.abc import (
|
||||
ValuesView,
|
||||
)
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import contextvars
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import enum
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import (
|
||||
@@ -49,6 +52,7 @@ from crewai.events.event_context import (
|
||||
reset_last_event_id,
|
||||
triggered_by_scope,
|
||||
)
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
@@ -61,16 +65,27 @@ from crewai.events.listeners.tracing.utils import (
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
FlowInputReceivedEvent,
|
||||
FlowInputRequestedEvent,
|
||||
FlowPausedEvent,
|
||||
FlowPlotEvent,
|
||||
FlowStartedEvent,
|
||||
HumanFeedbackReceivedEvent,
|
||||
HumanFeedbackRequestedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionPausedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow.async_feedback.providers import ConsoleProvider
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
|
||||
from crewai.flow.flow_context import current_flow_id, current_flow_request_id
|
||||
from crewai.flow.flow_config import flow_config
|
||||
from crewai.flow.flow_context import (
|
||||
current_flow_id,
|
||||
current_flow_method_name,
|
||||
current_flow_request_id,
|
||||
)
|
||||
from crewai.flow.flow_wrappers import (
|
||||
FlowCondition,
|
||||
FlowConditions,
|
||||
@@ -80,6 +95,9 @@ from crewai.flow.flow_wrappers import (
|
||||
SimpleFlowCondition,
|
||||
StartMethod,
|
||||
)
|
||||
from crewai.flow.human_feedback import HumanFeedbackResult
|
||||
from crewai.flow.input_provider import InputResponse
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.types import (
|
||||
FlowExecutionData,
|
||||
@@ -98,14 +116,18 @@ from crewai.flow.utils import (
|
||||
is_flow_method_name,
|
||||
is_simple_flow_condition,
|
||||
)
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai_files import FileInput
|
||||
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
from crewai.flow.human_feedback import HumanFeedbackResult
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.flow.input_provider import InputProvider
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
from crewai.flow.visualization import build_flow_structure, render_interactive
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
@@ -753,10 +775,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
name: str | None = None
|
||||
tracing: bool | None = None
|
||||
stream: bool = False
|
||||
memory: Any = (
|
||||
None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
|
||||
)
|
||||
input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
|
||||
memory: Memory | MemoryScope | MemorySlice | None = None
|
||||
input_provider: InputProvider | None = None
|
||||
|
||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
|
||||
class _FlowGeneric(cls): # type: ignore
|
||||
@@ -885,8 +905,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
"""
|
||||
if self.memory is None:
|
||||
raise ValueError("No memory configured for this flow")
|
||||
if isinstance(content, list):
|
||||
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
if isinstance(content, list) and isinstance(self.memory, Memory):
|
||||
return self.memory.remember_many(content, **kwargs)
|
||||
if isinstance(content, list):
|
||||
return [self.memory.remember(c, **kwargs) for c in content]
|
||||
return self.memory.remember(content, **kwargs)
|
||||
|
||||
def extract_memories(self, content: str) -> list[str]:
|
||||
@@ -1115,8 +1140,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
```
|
||||
"""
|
||||
if persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
|
||||
persistence = SQLiteFlowPersistence()
|
||||
|
||||
# Load pending feedback context and state
|
||||
@@ -1229,10 +1252,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Raises:
|
||||
ValueError: If no pending feedback context exists
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
from crewai.flow.human_feedback import HumanFeedbackResult
|
||||
|
||||
if self._pending_feedback_context is None:
|
||||
raise ValueError(
|
||||
"No pending feedback context. Use from_pending() to restore a paused flow."
|
||||
@@ -1315,13 +1334,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
except Exception as e:
|
||||
# Check if flow was paused again for human feedback (loop case)
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
if isinstance(e, HumanFeedbackPending):
|
||||
# Auto-save pending feedback (create default persistence if needed)
|
||||
if self._persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
|
||||
self._persistence = SQLiteFlowPersistence()
|
||||
|
||||
state_data = (
|
||||
@@ -1724,8 +1739,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
result_holder.append(result)
|
||||
except Exception as e:
|
||||
# HumanFeedbackPending is expected control flow, not an error
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
if isinstance(e, HumanFeedbackPending):
|
||||
result_holder.append(e)
|
||||
else:
|
||||
@@ -1794,8 +1807,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
result_holder.append(result)
|
||||
except Exception as e:
|
||||
# HumanFeedbackPending is expected control flow, not an error
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
if isinstance(e, HumanFeedbackPending):
|
||||
result_holder.append(e)
|
||||
else:
|
||||
@@ -1920,13 +1931,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
# Check if flow was paused for human feedback
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
if isinstance(e, HumanFeedbackPending):
|
||||
# Auto-save pending feedback (create default persistence if needed)
|
||||
if self._persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
|
||||
self._persistence = SQLiteFlowPersistence()
|
||||
|
||||
state_data = (
|
||||
@@ -2162,8 +2169,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Set method name in context so ask() can read it without
|
||||
# stack inspection. Must happen before copy_context() so the
|
||||
# value propagates into the thread pool for sync methods.
|
||||
from crewai.flow.flow_context import current_flow_method_name
|
||||
|
||||
method_name_token = current_flow_method_name.set(method_name)
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
@@ -2171,8 +2176,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
else:
|
||||
# Run sync methods in thread pool for isolation
|
||||
# This allows Agent.kickoff() to work synchronously inside Flow methods
|
||||
import contextvars
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
result = await asyncio.to_thread(ctx.run, method, *args, **kwargs)
|
||||
finally:
|
||||
@@ -2206,15 +2209,11 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return result, finished_event_id
|
||||
except Exception as e:
|
||||
# Check if this is a HumanFeedbackPending exception (paused, not failed)
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
if isinstance(e, HumanFeedbackPending):
|
||||
e.context.method_name = method_name
|
||||
|
||||
# Auto-save pending feedback (create default persistence if needed)
|
||||
if self._persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
|
||||
self._persistence = SQLiteFlowPersistence()
|
||||
|
||||
# Emit paused event (not failed)
|
||||
@@ -2646,8 +2645,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
except Exception as e:
|
||||
# Don't log HumanFeedbackPending as an error - it's expected control flow
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
if not isinstance(e, HumanFeedbackPending):
|
||||
logger.error(f"Error executing listener {listener_name}: {e}")
|
||||
raise
|
||||
@@ -2665,9 +2662,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Returns:
|
||||
An object implementing the ``InputProvider`` protocol.
|
||||
"""
|
||||
from crewai.flow.async_feedback.providers import ConsoleProvider
|
||||
from crewai.flow.flow_config import flow_config
|
||||
|
||||
if self.input_provider is not None:
|
||||
return self.input_provider
|
||||
if flow_config.input_provider is not None:
|
||||
@@ -2753,19 +2747,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return topic
|
||||
```
|
||||
"""
|
||||
from concurrent.futures import (
|
||||
ThreadPoolExecutor,
|
||||
TimeoutError as FuturesTimeoutError,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowInputReceivedEvent,
|
||||
FlowInputRequestedEvent,
|
||||
)
|
||||
from crewai.flow.flow_context import current_flow_method_name
|
||||
from crewai.flow.input_provider import InputResponse
|
||||
|
||||
method_name = current_flow_method_name.get("unknown")
|
||||
|
||||
# Emit input requested event
|
||||
@@ -2796,7 +2777,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
try:
|
||||
raw = future.result(timeout=timeout)
|
||||
except FuturesTimeoutError:
|
||||
except TimeoutError:
|
||||
future.cancel()
|
||||
raw = None
|
||||
finally:
|
||||
@@ -2869,12 +2850,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Returns:
|
||||
The human's feedback as a string. Empty string if no feedback provided.
|
||||
"""
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.events.types.flow_events import (
|
||||
HumanFeedbackReceivedEvent,
|
||||
HumanFeedbackRequestedEvent,
|
||||
)
|
||||
|
||||
# Emit feedback requested event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -2948,18 +2923,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Returns:
|
||||
One of the outcome strings that best matches the feedback intent.
|
||||
"""
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM as BaseLLMClass
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
|
||||
llm_instance: BaseLLMClass
|
||||
llm_instance: BaseLLM
|
||||
if isinstance(llm, str):
|
||||
llm_instance = LLM(model=llm)
|
||||
elif isinstance(llm, BaseLLMClass):
|
||||
elif isinstance(llm, BaseLLM):
|
||||
llm_instance = llm
|
||||
else:
|
||||
raise ValueError(f"Invalid llm type: {type(llm)}. Expected str or BaseLLM.")
|
||||
@@ -2992,8 +2959,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
if isinstance(response, str):
|
||||
import json
|
||||
|
||||
try:
|
||||
parsed = json.loads(response)
|
||||
return str(parsed.get("outcome", outcomes[0]))
|
||||
@@ -3058,8 +3023,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
This method uses the centralized Rich console formatter for output
|
||||
and the standard logging module for log level support.
|
||||
"""
|
||||
from crewai.events.event_listener import event_listener
|
||||
|
||||
event_listener.formatter.console.print(message, style=color)
|
||||
if level == "info":
|
||||
logger.info(message)
|
||||
|
||||
@@ -83,6 +83,7 @@ if TYPE_CHECKING:
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
|
||||
T = TypeVar("T", bound=EmbeddingFunction[Any])
|
||||
|
||||
@@ -349,6 +350,10 @@ def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
|
||||
def build_embedder(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: EmbedderConfig) -> EmbeddingFunction[Any]: ...
|
||||
|
||||
|
||||
def build_embedder(spec): # type: ignore[no-untyped-def]
|
||||
"""Build an embedding function from either a provider spec or a provider instance.
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.callable import SerializableCallable
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
|
||||
from crewai.utilities.converter import Converter, convert_to_model
|
||||
@@ -123,8 +124,9 @@ class Task(BaseModel):
|
||||
description="Configuration for the agent",
|
||||
default=None,
|
||||
)
|
||||
callback: Any | None = Field(
|
||||
description="Callback to be executed after the task is completed.", default=None
|
||||
callback: SerializableCallable | None = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after the task is completed.",
|
||||
)
|
||||
agent: BaseAgent | None = Field(
|
||||
description="Agent responsible for execution the task.", default=None
|
||||
|
||||
96
lib/crewai/src/crewai/types/callable.py
Normal file
96
lib/crewai/src/crewai/types/callable.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Serializable callable type for Pydantic models.
|
||||
|
||||
All callables (ex., named functions, lambdas, closures, methods) are serialized
|
||||
via ``cloudpickle`` + base64. On deserialization the base64 payload is
|
||||
decoded and unpickled back into a live callable.
|
||||
|
||||
Deserialization is **opt-in** to prevent arbitrary code execution from
|
||||
untrusted payloads. Callers must use :data:`allow_pickle_deserialization` to enable it::
|
||||
|
||||
with allow_pickle_deserialization:
|
||||
task = Task.model_validate_json(untrusted_json)
|
||||
|
||||
``cloudpickle`` is an optional dependency. Serialization and deserialization
|
||||
will raise ``RuntimeError`` if it is not installed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Annotated, Any
|
||||
|
||||
from pydantic import BeforeValidator, PlainSerializer, WithJsonSchema
|
||||
|
||||
|
||||
_ALLOW_PICKLE: ContextVar[bool] = ContextVar("_ALLOW_PICKLE", default=False)
|
||||
_ALLOW_PICKLE_TOKEN: ContextVar[Token[bool] | None] = ContextVar(
|
||||
"_ALLOW_PICKLE_TOKEN", default=None
|
||||
)
|
||||
|
||||
|
||||
def _import_cloudpickle() -> Any:
|
||||
try:
|
||||
import cloudpickle # type: ignore[import-untyped]
|
||||
except ModuleNotFoundError:
|
||||
raise RuntimeError(
|
||||
"cloudpickle is required for callable serialization. "
|
||||
"Install it with: uv add 'crewai[pickling]'"
|
||||
) from None
|
||||
return cloudpickle
|
||||
|
||||
|
||||
class _AllowPickleDeserialization:
|
||||
"""Reentrant context manager that opts in to cloudpickle deserialization.
|
||||
|
||||
Usage::
|
||||
|
||||
with allow_pickle_deserialization:
|
||||
task = Task.model_validate_json(payload)
|
||||
"""
|
||||
|
||||
def __enter__(self) -> None:
|
||||
_ALLOW_PICKLE_TOKEN.set(_ALLOW_PICKLE.set(True))
|
||||
|
||||
def __exit__(self, *_: object) -> None:
|
||||
token = _ALLOW_PICKLE_TOKEN.get()
|
||||
if token is not None:
|
||||
_ALLOW_PICKLE.reset(token)
|
||||
|
||||
|
||||
allow_pickle_deserialization = _AllowPickleDeserialization()
|
||||
|
||||
|
||||
def _deserialize_callable(v: str | Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""Deserialize a base64-encoded cloudpickle payload, or pass through if already callable."""
|
||||
if isinstance(v, str):
|
||||
if not _ALLOW_PICKLE.get():
|
||||
raise RuntimeError(
|
||||
"Refusing to unpickle a callable from untrusted data. "
|
||||
"Wrap the deserialization call with "
|
||||
"`with allow_pickle_deserialization: ...` "
|
||||
"if you trust the source."
|
||||
)
|
||||
cloudpickle = _import_cloudpickle()
|
||||
obj = cloudpickle.loads(base64.b85decode(v))
|
||||
if not callable(obj):
|
||||
raise ValueError(
|
||||
f"Deserialized object is {type(obj).__name__}, not a callable"
|
||||
)
|
||||
return obj # type: ignore[no-any-return]
|
||||
return v
|
||||
|
||||
|
||||
def _serialize_callable(v: Callable[..., Any]) -> str:
|
||||
"""Serialize any callable to a base64-encoded cloudpickle payload."""
|
||||
cloudpickle = _import_cloudpickle()
|
||||
return base64.b85encode(cloudpickle.dumps(v)).decode("ascii")
|
||||
|
||||
|
||||
SerializableCallable = Annotated[
|
||||
Callable[..., Any],
|
||||
BeforeValidator(_deserialize_callable),
|
||||
PlainSerializer(_serialize_callable, return_type=str, when_used="json"),
|
||||
WithJsonSchema({"type": "string"}),
|
||||
]
|
||||
@@ -0,0 +1,113 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"Based on the following human feedback,
|
||||
determine which outcome best matches their intent.\n\nFeedback: I approve this\n\nPossible
|
||||
outcomes: approved, rejected\n\nRespond with ONLY one of the exact outcome values
|
||||
listed above, nothing else."}],"model":"gpt-4o-mini","response_format":{"type":"json_schema","json_schema":{"schema":{"description":"The
|
||||
outcome that best matches the human''s feedback intent.","properties":{"outcome":{"description":"The
|
||||
outcome that best matches the feedback. Must be one of: approved, rejected","enum":["approved","rejected"],"title":"Outcome","type":"string"}},"required":["outcome"],"title":"FeedbackOutcome","type":"object","additionalProperties":false},"name":"FeedbackOutcome","strict":true}},"stream":false}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
authorization:
|
||||
- AUTHORIZATION-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '782'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
x-stainless-arch:
|
||||
- X-STAINLESS-ARCH-XXX
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-helper-method:
|
||||
- beta.chat.completions.parse
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.12
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-DHDHCheu5DvlB6xjTrEDq0nfzLlrf\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1772994982,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"{\\\"outcome\\\":\\\"approved\\\"}\",\n
|
||||
\ \"refusal\": null,\n \"annotations\": []\n },\n \"logprobs\":
|
||||
null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
|
||||
130,\n \"completion_tokens\": 6,\n \"total_tokens\": 136,\n \"prompt_tokens_details\":
|
||||
{\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_cf6f0a1ff1\"\n}\n"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Sun, 08 Mar 2026 18:36:22 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
- STS-XXX
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '361'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
set-cookie:
|
||||
- SET-COOKIE-XXX
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
- X-RATELIMIT-LIMIT-TOKENS-XXX
|
||||
x-ratelimit-remaining-requests:
|
||||
- X-RATELIMIT-REMAINING-REQUESTS-XXX
|
||||
x-ratelimit-remaining-tokens:
|
||||
- X-RATELIMIT-REMAINING-TOKENS-XXX
|
||||
x-ratelimit-reset-requests:
|
||||
- X-RATELIMIT-RESET-REQUESTS-XXX
|
||||
x-ratelimit-reset-tokens:
|
||||
- X-RATELIMIT-RESET-TOKENS-XXX
|
||||
x-request-id:
|
||||
- X-REQUEST-ID-XXX
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,113 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"Based on the following human feedback,
|
||||
determine which outcome best matches their intent.\n\nFeedback: Unclear feedback\n\nPossible
|
||||
outcomes: approved, rejected\n\nRespond with ONLY one of the exact outcome values
|
||||
listed above, nothing else."}],"model":"gpt-4o-mini","response_format":{"type":"json_schema","json_schema":{"schema":{"description":"The
|
||||
outcome that best matches the human''s feedback intent.","properties":{"outcome":{"description":"The
|
||||
outcome that best matches the feedback. Must be one of: approved, rejected","enum":["approved","rejected"],"title":"Outcome","type":"string"}},"required":["outcome"],"title":"FeedbackOutcome","type":"object","additionalProperties":false},"name":"FeedbackOutcome","strict":true}},"stream":false}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
authorization:
|
||||
- AUTHORIZATION-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '784'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
x-stainless-arch:
|
||||
- X-STAINLESS-ARCH-XXX
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-helper-method:
|
||||
- beta.chat.completions.parse
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.12
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-DHDHDlji53YRtj69Ulq5E9SjBqccI\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1772994983,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"{\\\"outcome\\\":\\\"rejected\\\"}\",\n
|
||||
\ \"refusal\": null,\n \"annotations\": []\n },\n \"logprobs\":
|
||||
null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
|
||||
130,\n \"completion_tokens\": 7,\n \"total_tokens\": 137,\n \"prompt_tokens_details\":
|
||||
{\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_a1ddba3226\"\n}\n"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Sun, 08 Mar 2026 18:36:24 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
- STS-XXX
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '317'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
set-cookie:
|
||||
- SET-COOKIE-XXX
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
- X-RATELIMIT-LIMIT-TOKENS-XXX
|
||||
x-ratelimit-remaining-requests:
|
||||
- X-RATELIMIT-REMAINING-REQUESTS-XXX
|
||||
x-ratelimit-remaining-tokens:
|
||||
- X-RATELIMIT-REMAINING-TOKENS-XXX
|
||||
x-ratelimit-reset-requests:
|
||||
- X-RATELIMIT-RESET-REQUESTS-XXX
|
||||
x-ratelimit-reset-tokens:
|
||||
- X-RATELIMIT-RESET-TOKENS-XXX
|
||||
x-request-id:
|
||||
- X-REQUEST-ID-XXX
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,113 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"Based on the following human feedback,
|
||||
determine which outcome best matches their intent.\n\nFeedback: Looks good\n\nPossible
|
||||
outcomes: approved, rejected\n\nRespond with ONLY one of the exact outcome values
|
||||
listed above, nothing else."}],"model":"gpt-4o-mini","response_format":{"type":"json_schema","json_schema":{"schema":{"description":"The
|
||||
outcome that best matches the human''s feedback intent.","properties":{"outcome":{"description":"The
|
||||
outcome that best matches the feedback. Must be one of: approved, rejected","enum":["approved","rejected"],"title":"Outcome","type":"string"}},"required":["outcome"],"title":"FeedbackOutcome","type":"object","additionalProperties":false},"name":"FeedbackOutcome","strict":true}},"stream":false}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
authorization:
|
||||
- AUTHORIZATION-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '778'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
x-stainless-arch:
|
||||
- X-STAINLESS-ARCH-XXX
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-helper-method:
|
||||
- beta.chat.completions.parse
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.12
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-DHDHEVhZlU19TjrqDy0sKeWkKRINn\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1772994984,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"{\\\"outcome\\\":\\\"approved\\\"}\",\n
|
||||
\ \"refusal\": null,\n \"annotations\": []\n },\n \"logprobs\":
|
||||
null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
|
||||
129,\n \"completion_tokens\": 6,\n \"total_tokens\": 135,\n \"prompt_tokens_details\":
|
||||
{\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_a1ddba3226\"\n}\n"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Sun, 08 Mar 2026 18:36:24 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
- STS-XXX
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '253'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
set-cookie:
|
||||
- SET-COOKIE-XXX
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
- X-RATELIMIT-LIMIT-TOKENS-XXX
|
||||
x-ratelimit-remaining-requests:
|
||||
- X-RATELIMIT-REMAINING-REQUESTS-XXX
|
||||
x-ratelimit-remaining-tokens:
|
||||
- X-RATELIMIT-REMAINING-TOKENS-XXX
|
||||
x-ratelimit-reset-requests:
|
||||
- X-RATELIMIT-RESET-REQUESTS-XXX
|
||||
x-ratelimit-reset-tokens:
|
||||
- X-RATELIMIT-RESET-TOKENS-XXX
|
||||
x-request-id:
|
||||
- X-REQUEST-ID-XXX
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -897,7 +897,7 @@ class TestCollapseToOutcomeJsonParsing:
|
||||
"""Test that JSON string response from LLM is correctly parsed."""
|
||||
flow = Flow()
|
||||
|
||||
with patch("crewai.llm.LLM") as MockLLM:
|
||||
with patch("crewai.flow.flow.LLM") as MockLLM:
|
||||
mock_llm = MagicMock()
|
||||
# Simulate LLM returning JSON string (the bug we fixed)
|
||||
mock_llm.call.return_value = '{"outcome": "approved"}'
|
||||
@@ -915,7 +915,7 @@ class TestCollapseToOutcomeJsonParsing:
|
||||
"""Test that plain string response is correctly matched."""
|
||||
flow = Flow()
|
||||
|
||||
with patch("crewai.llm.LLM") as MockLLM:
|
||||
with patch("crewai.flow.flow.LLM") as MockLLM:
|
||||
mock_llm = MagicMock()
|
||||
# Simulate LLM returning plain outcome string
|
||||
mock_llm.call.return_value = "rejected"
|
||||
@@ -933,7 +933,7 @@ class TestCollapseToOutcomeJsonParsing:
|
||||
"""Test that invalid JSON falls back to string matching."""
|
||||
flow = Flow()
|
||||
|
||||
with patch("crewai.llm.LLM") as MockLLM:
|
||||
with patch("crewai.flow.flow.LLM") as MockLLM:
|
||||
mock_llm = MagicMock()
|
||||
# Invalid JSON that contains "approved"
|
||||
mock_llm.call.return_value = "{invalid json but says approved"
|
||||
@@ -951,7 +951,7 @@ class TestCollapseToOutcomeJsonParsing:
|
||||
"""Test that LLM exception triggers fallback to simple prompting."""
|
||||
flow = Flow()
|
||||
|
||||
with patch("crewai.llm.LLM") as MockLLM:
|
||||
with patch("crewai.flow.flow.LLM") as MockLLM:
|
||||
mock_llm = MagicMock()
|
||||
# First call raises, second call succeeds (fallback)
|
||||
mock_llm.call.side_effect = [
|
||||
|
||||
@@ -349,56 +349,38 @@ class TestHumanFeedbackHistory:
|
||||
class TestCollapseToOutcome:
|
||||
"""Tests for the _collapse_to_outcome method."""
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_exact_match(self):
|
||||
"""Test exact match returns the correct outcome."""
|
||||
flow = Flow()
|
||||
result = flow._collapse_to_outcome(
|
||||
feedback="I approve this",
|
||||
outcomes=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
assert result in ("approved", "rejected")
|
||||
|
||||
with patch("crewai.llm.LLM") as MockLLM:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.call.return_value = "approved"
|
||||
MockLLM.return_value = mock_llm
|
||||
|
||||
result = flow._collapse_to_outcome(
|
||||
feedback="I approve this",
|
||||
outcomes=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
|
||||
assert result == "approved"
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_partial_match(self):
|
||||
"""Test partial match finds the outcome in the response."""
|
||||
flow = Flow()
|
||||
result = flow._collapse_to_outcome(
|
||||
feedback="Looks good",
|
||||
outcomes=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
assert result in ("approved", "rejected")
|
||||
|
||||
with patch("crewai.llm.LLM") as MockLLM:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.call.return_value = "The outcome is approved based on the feedback"
|
||||
MockLLM.return_value = mock_llm
|
||||
|
||||
result = flow._collapse_to_outcome(
|
||||
feedback="Looks good",
|
||||
outcomes=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
|
||||
assert result == "approved"
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_fallback_to_first(self):
|
||||
"""Test that unmatched response falls back to first outcome."""
|
||||
flow = Flow()
|
||||
|
||||
with patch("crewai.llm.LLM") as MockLLM:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.call.return_value = "something completely different"
|
||||
MockLLM.return_value = mock_llm
|
||||
|
||||
result = flow._collapse_to_outcome(
|
||||
feedback="Unclear feedback",
|
||||
outcomes=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
|
||||
assert result == "approved" # First in list
|
||||
result = flow._collapse_to_outcome(
|
||||
feedback="Unclear feedback",
|
||||
outcomes=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
assert result in ("approved", "rejected")
|
||||
|
||||
|
||||
# -- HITL Learning tests --
|
||||
|
||||
Reference in New Issue
Block a user