Compare commits

..

3 Commits

Author           SHA1        Message                                                Date
Greyson Lalonde  68c9990eef  feat: use emit decorator on external, improve typing   2025-10-22 13:48:37 -04:00
Greyson Lalonde  8d8772d607  chore: cleanup memo0, add typing                       2025-10-22 12:52:28 -04:00
Greyson Lalonde  d916bc8695  feat: create generic event emission decorator          2025-10-22 11:41:09 -04:00
24 changed files with 4690 additions and 4516 deletions

View File

@@ -12,7 +12,7 @@ dependencies = [
"pytube>=15.0.0",
"requests>=2.32.5",
"docker>=7.1.0",
"crewai==1.2.0",
"crewai==1.1.0",
"lancedb>=0.5.4",
"tiktoken>=0.8.0",
"beautifulsoup4>=4.13.4",

View File

@@ -287,4 +287,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.2.0"
__version__ = "1.1.0"

View File

@@ -49,7 +49,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.2.0",
"crewai-tools==1.1.0",
]
embeddings = [
"tiktoken~=0.8.0"

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.2.0"
__version__ = "1.1.0"
_telemetry_submitted = False

View File

@@ -322,7 +322,7 @@ MODELS = {
],
}
DEFAULT_LLM_MODEL = "gpt-4.1-mini"
DEFAULT_LLM_MODEL = "gpt-4o-mini"
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.2.0"
"crewai[tools]==1.1.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.2.0"
"crewai[tools]==1.1.0"
]
[project.scripts]

View File

@@ -0,0 +1,388 @@
"""Decorators for automatic event lifecycle management.
This module provides decorators that automatically emit started/completed/failed
events for methods, reducing boilerplate code across the codebase.
"""
from collections.abc import Callable
from functools import wraps
import time
from typing import Any, Concatenate, Literal, ParamSpec, TypeVar, TypedDict, cast
from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.events.types.task_events import (
TaskCompletedEvent,
TaskFailedEvent,
TaskStartedEvent,
)
P = ParamSpec("P")
R = TypeVar("R")
EventPrefix = Literal[
"task",
"memory_save",
"memory_query",
"crew_kickoff",
"crew_train",
"crew_test",
]
EventParams = dict[str, Any]
StartedParamsFn = Callable[[Any, tuple[Any, ...], dict[str, Any]], EventParams]
CompletedParamsFn = Callable[
[Any, tuple[Any, ...], dict[str, Any], Any, float], EventParams
]
FailedParamsFn = Callable[
[Any, tuple[Any, ...], dict[str, Any], Exception], EventParams
]
class LifecycleEventClasses(TypedDict):
"""Mapping of lifecycle event types to their corresponding event classes."""
started: type[BaseEvent]
completed: type[BaseEvent]
failed: type[BaseEvent]
class EventClassMap(TypedDict):
"""Mapping of event prefixes to their lifecycle event classes."""
task: LifecycleEventClasses
memory_save: LifecycleEventClasses
memory_query: LifecycleEventClasses
crew_kickoff: LifecycleEventClasses
crew_train: LifecycleEventClasses
crew_test: LifecycleEventClasses
class LifecycleParamExtractors(TypedDict):
"""Parameter extractors for lifecycle events."""
started_params: StartedParamsFn
completed_params: CompletedParamsFn
failed_params: FailedParamsFn
EVENT_CLASS_MAP: EventClassMap = {
"task": {
"started": TaskStartedEvent,
"completed": TaskCompletedEvent,
"failed": TaskFailedEvent,
},
"memory_save": {
"started": MemorySaveStartedEvent,
"completed": MemorySaveCompletedEvent,
"failed": MemorySaveFailedEvent,
},
"memory_query": {
"started": MemoryQueryStartedEvent,
"completed": MemoryQueryCompletedEvent,
"failed": MemoryQueryFailedEvent,
},
"crew_kickoff": {
"started": CrewKickoffStartedEvent,
"completed": CrewKickoffCompletedEvent,
"failed": CrewKickoffFailedEvent,
},
"crew_train": {
"started": CrewTrainStartedEvent,
"completed": CrewTrainCompletedEvent,
"failed": CrewTrainFailedEvent,
},
"crew_test": {
"started": CrewTestStartedEvent,
"completed": CrewTestCompletedEvent,
"failed": CrewTestFailedEvent,
},
}
def _extract_arg(
position: str | int, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> Any:
"""Extract argument by name from kwargs or by position from args.
Args:
position: Argument name (str) or positional index (int).
args: Positional arguments tuple.
kwargs: Keyword arguments dict.
Returns:
Extracted argument value or None if not found.
"""
if isinstance(position, str):
return kwargs.get(position)
try:
return args[position]
except IndexError:
return None
def lifecycle_params(
*,
args_map: dict[str, str | int] | None = None,
context: dict[str, Any | Callable[[Any], Any]] | None = None,
result_name: str | None = None,
elapsed_name: str = "elapsed_ms",
) -> LifecycleParamExtractors:
"""Helper to create lifecycle event parameter extractors with reduced boilerplate.
This function generates the three parameter extractors (started_params, completed_params,
failed_params) needed by @with_lifecycle_events, following common patterns and reducing
code duplication.
Args:
args_map: Maps event parameter names to function argument names (str) or positions (int).
Example: {"query": "query", "value": 0} extracts kwargs["query"] and args[0]
context: Static or dynamic context fields included in all events.
Values can be static (Any) or callables that receive self and return a value.
Example: {"source_type": "external_memory", "from_agent": lambda self: self.agent}
result_name: Name for the result in completed_params (e.g., "results", "output").
If None, result is not included in the event.
elapsed_name: Name for elapsed time in completed_params (default: "elapsed_ms").
Returns:
Dictionary with keys "started_params", "completed_params", "failed_params"
containing the appropriate lambda functions for @with_lifecycle_events.
Example:
>>> param_extractors = lifecycle_params(
... args_map={"value": "value", "metadata": "metadata"},
... context={
... "source_type": "external_memory",
... "from_agent": lambda self: self.agent,
... "from_task": lambda self: self.task,
... },
... elapsed_name="save_time_ms",
... )
>>> param_extractors["started_params"] # doctest: +ELLIPSIS
<function lifecycle_params.<locals>.started_params_fn at 0x...>
"""
args_map = args_map or {}
context = context or {}
static_context: EventParams = {}
dynamic_context: dict[str, Callable[[Any], Any]] = {}
for ctx_key, ctx_value in context.items():
if callable(ctx_value):
dynamic_context[ctx_key] = ctx_value
else:
static_context[ctx_key] = ctx_value
def started_params_fn(
self: Any, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> EventParams:
"""Extract parameters for started event.
Args:
self: Instance emitting the event.
args: Positional arguments from decorated method.
kwargs: Keyword arguments from decorated method.
Returns:
Parameters for started event.
"""
params: EventParams = {**static_context}
for param_name, arg_spec in args_map.items():
params[param_name] = _extract_arg(arg_spec, args, kwargs)
for key, func in dynamic_context.items():
params[key] = func(self)
return params
def completed_params_fn(
self: Any,
args: tuple[Any, ...],
kwargs: dict[str, Any],
result: Any,
elapsed_ms: float,
) -> EventParams:
"""Extract parameters for completed event.
Args:
self: Instance emitting the event.
args: Positional arguments from decorated method.
kwargs: Keyword arguments from decorated method.
result: Return value from decorated method.
elapsed_ms: Elapsed execution time in milliseconds.
Returns:
Parameters for completed event.
"""
params: EventParams = {**static_context}
for param_name, arg_spec in args_map.items():
params[param_name] = _extract_arg(arg_spec, args, kwargs)
if result_name is not None:
params[result_name] = result
params[elapsed_name] = elapsed_ms
for key, func in dynamic_context.items():
params[key] = func(self)
return params
def failed_params_fn(
self: Any, args: tuple[Any, ...], kwargs: dict[str, Any], exc: Exception
) -> EventParams:
"""Extract parameters for failed event.
Args:
self: Instance emitting the event.
args: Positional arguments from decorated method.
kwargs: Keyword arguments from decorated method.
exc: Exception raised during execution.
Returns:
Parameters for failed event.
"""
params: EventParams = {**static_context}
for param_name, arg_spec in args_map.items():
params[param_name] = _extract_arg(arg_spec, args, kwargs)
params["error"] = str(exc)
for key, func in dynamic_context.items():
params[key] = func(self)
return params
return {
"started_params": started_params_fn,
"completed_params": completed_params_fn,
"failed_params": failed_params_fn,
}
def with_lifecycle_events(
prefix: EventPrefix,
*,
args_map: dict[str, str | int] | None = None,
context: dict[str, Any | Callable[[Any], Any]] | None = None,
result_name: str | None = None,
elapsed_name: str = "elapsed_ms",
) -> Callable[[Callable[Concatenate[Any, P], R]], Callable[Concatenate[Any, P], R]]:
"""Decorator to automatically emit lifecycle events (started/completed/failed).
This decorator wraps a method to emit events at different stages of execution:
- StartedEvent: Emitted before method execution
- CompletedEvent: Emitted after successful execution (includes timing via monotonic_ns)
- FailedEvent: Emitted if an exception occurs (re-raises the exception)
Args:
prefix: Event prefix from the EventPrefix Literal type. Determines which
event classes to use (e.g., "task" -> TaskStartedEvent, etc.)
args_map: Maps event parameter names to function argument names (str) or positions (int).
Example: {"query": "query", "value": 0} extracts kwargs["query"] and args[0]
context: Static or dynamic context fields included in all events.
Values can be static (Any) or callables that receive self and return a value.
Example: {"source_type": "external_memory", "from_agent": lambda self: self.agent}
result_name: Name for the result in completed_params (e.g., "results", "output").
If None, result is not included in the event.
elapsed_name: Name for elapsed time in completed_params (default: "elapsed_ms").
Returns:
Decorated function that emits lifecycle events.
Example:
>>> @with_lifecycle_events(
... "memory_save",
... args_map={"value": "value", "metadata": "metadata"},
... context={
... "source_type": "external_memory",
... "from_agent": lambda self: self.agent,
... },
... elapsed_name="save_time_ms",
... )
... def save(self, value: Any, metadata: dict[str, Any] | None = None) -> None:
... pass
"""
param_extractors = lifecycle_params(
args_map=args_map,
context=context,
result_name=result_name,
elapsed_name=elapsed_name,
)
started_params: StartedParamsFn = param_extractors["started_params"]
completed_params: CompletedParamsFn = param_extractors["completed_params"]
failed_params: FailedParamsFn = param_extractors["failed_params"]
event_classes = EVENT_CLASS_MAP[prefix]
def decorator(
func: Callable[Concatenate[Any, P], R],
) -> Callable[Concatenate[Any, P], R]:
"""Apply lifecycle event emission to the decorated function.
Args:
func: Function to decorate.
Returns:
Decorated function with lifecycle event emission.
"""
@wraps(func)
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R:
"""Execute function with lifecycle event emission.
Args:
self: Instance calling the method.
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
Result from the decorated function.
Raises:
Exception: Re-raises any exception after emitting failed event.
"""
started_event_params = started_params(self, args, kwargs)
crewai_event_bus.emit(
self,
event_classes["started"](**started_event_params),
)
start_time = time.monotonic_ns()
try:
result = func(self, *args, **kwargs)
completed_event_params = completed_params(
self,
args,
kwargs,
result,
(time.monotonic_ns() - start_time) / 1_000_000,
)
crewai_event_bus.emit(
self,
event_classes["completed"](**completed_event_params),
)
return result
except Exception as e:
failed_event_params = failed_params(self, args, kwargs, e)
crewai_event_bus.emit(
self,
event_classes["failed"](**failed_event_params),
)
raise
return cast(Callable[Concatenate[Any, P], R], wrapper)
return decorator
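
Taken together, the wrapper emits started before the call, completed with monotonic timing after it, and failed (then re-raises) on exception. A stand-alone sketch of that ordering, with a plain list standing in for crewai_event_bus (illustrative only, not the shipped API):

import time
from functools import wraps
from typing import Any

emitted: list[tuple[str, dict[str, Any]]] = []  # stand-in for crewai_event_bus

def demo_lifecycle(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        emitted.append(("started", {}))
        start = time.monotonic_ns()
        try:
            result = func(self, *args, **kwargs)
            # Same timing scheme as the decorator: monotonic_ns converted to ms.
            emitted.append(("completed", {"elapsed_ms": (time.monotonic_ns() - start) / 1_000_000}))
            return result
        except Exception as exc:
            emitted.append(("failed", {"error": str(exc)}))
            raise
    return wrapper

class Repo:
    @demo_lifecycle
    def save(self, value: str) -> str:
        return value.upper()

Repo().save("hello")
assert [name for name, _ in emitted] == ["started", "completed"]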

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
from pyvis.network import Network # type: ignore[import-untyped]
@@ -29,7 +29,7 @@ _printer = Printer()
class FlowPlot:
"""Handles the creation and rendering of flow visualization diagrams."""
def __init__(self, flow: Flow[Any]) -> None:
def __init__(self, flow: Flow) -> None:
"""
Initialize FlowPlot with a flow object.
@@ -136,7 +136,7 @@ class FlowPlot:
f"Unexpected error during flow visualization: {e!s}"
) from e
finally:
self._cleanup_pyvis_lib(filename)
self._cleanup_pyvis_lib()
def _generate_final_html(self, network_html: str) -> str:
"""
@@ -186,33 +186,26 @@ class FlowPlot:
raise IOError(f"Failed to generate visualization HTML: {e!s}") from e
@staticmethod
def _cleanup_pyvis_lib(filename: str) -> None:
def _cleanup_pyvis_lib() -> None:
"""
Clean up the generated lib folder from pyvis.
This method safely removes the temporary lib directory created by pyvis
during network visualization generation. The lib folder is created in the
same directory as the output HTML file.
Parameters
----------
filename : str
The output filename (without .html extension) used for the visualization.
during network visualization generation.
"""
try:
import shutil
output_dir = os.path.dirname(os.path.abspath(filename)) or os.getcwd()
lib_folder = os.path.join(output_dir, "lib")
lib_folder = safe_path_join("lib", root=os.getcwd())
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
vis_js = os.path.join(lib_folder, "vis-network.min.js")
if os.path.exists(vis_js):
shutil.rmtree(lib_folder)
import shutil
shutil.rmtree(lib_folder)
except ValueError as e:
_printer.print(f"Error validating lib folder path: {e}", color="red")
except Exception as e:
_printer.print(f"Error cleaning up lib folder: {e}", color="red")
def plot_flow(flow: Flow[Any], filename: str = "flow_plot") -> None:
def plot_flow(flow: Flow, filename: str = "flow_plot") -> None:
"""
Convenience function to create and save a flow visualization.
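
The cleanup now resolves the lib folder against the current working directory via safe_path_join rather than deriving it from the output filename. A minimal sketch of the containment check such a helper presumably performs (the real one lives in crewai.flow.path_utils; its exact behavior is an assumption here):

import os

def sketch_safe_path_join(*parts: str, root: str) -> str:
    # Join under root, then refuse any result that escapes root, so a
    # crafted component like "../outside" can never be handed to rmtree.
    root_abs = os.path.abspath(root)
    candidate = os.path.abspath(os.path.join(root_abs, *parts))
    if os.path.commonpath([candidate, root_abs]) != root_abs:
        raise ValueError(f"path escapes root: {candidate}")
    return candidate

print(sketch_safe_path_join("lib", root=os.getcwd()))  # <cwd>/lib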

View File

@@ -1,8 +1,5 @@
"""HTML template processing and generation for flow visualization diagrams."""
import base64
import re
from typing import Any
from crewai.flow.path_utils import validate_path_exists
@@ -10,7 +7,7 @@ from crewai.flow.path_utils import validate_path_exists
class HTMLTemplateHandler:
"""Handles HTML template processing and generation for flow visualization diagrams."""
def __init__(self, template_path: str, logo_path: str) -> None:
def __init__(self, template_path, logo_path):
"""
Initialize HTMLTemplateHandler with validated template and logo paths.
@@ -32,23 +29,23 @@ class HTMLTemplateHandler:
except ValueError as e:
raise ValueError(f"Invalid template or logo path: {e}") from e
def read_template(self) -> str:
def read_template(self):
"""Read and return the HTML template file contents."""
with open(self.template_path, "r", encoding="utf-8") as f:
return f.read()
def encode_logo(self) -> str:
def encode_logo(self):
"""Convert the logo SVG file to base64 encoded string."""
with open(self.logo_path, "rb") as logo_file:
logo_svg_data = logo_file.read()
return base64.b64encode(logo_svg_data).decode("utf-8")
def extract_body_content(self, html: str) -> str:
def extract_body_content(self, html):
"""Extract and return content between body tags from HTML string."""
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
return match.group(1) if match else ""
def generate_legend_items_html(self, legend_items: list[dict[str, Any]]) -> str:
def generate_legend_items_html(self, legend_items):
"""Generate HTML markup for the legend items."""
legend_items_html = ""
for item in legend_items:
@@ -76,9 +73,7 @@ class HTMLTemplateHandler:
"""
return legend_items_html
def generate_final_html(
self, network_body: str, legend_items_html: str, title: str = "Flow Plot"
) -> str:
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
"""Combine all components into final HTML document with network visualization."""
html_template = self.read_template()
logo_svg_base64 = self.encode_logo()

View File

@@ -1,23 +1,4 @@
"""Legend generation for flow visualization diagrams."""
from typing import Any
from crewai.flow.config import FlowColors
def get_legend_items(colors: FlowColors) -> list[dict[str, Any]]:
"""Generate legend items based on flow colors.
Parameters
----------
colors : FlowColors
Dictionary containing color definitions for flow elements.
Returns
-------
list[dict[str, Any]]
List of legend item dictionaries with labels and styling.
"""
def get_legend_items(colors):
return [
{"label": "Start Method", "color": colors["start"]},
{"label": "Method", "color": colors["method"]},
@@ -43,19 +24,7 @@ def get_legend_items(colors: FlowColors) -> list[dict[str, Any]]:
]
def generate_legend_items_html(legend_items: list[dict[str, Any]]) -> str:
"""Generate HTML markup for legend items.
Parameters
----------
legend_items : list[dict[str, Any]]
List of legend item dictionaries containing labels and styling.
Returns
-------
str
HTML string containing formatted legend items.
"""
def generate_legend_items_html(legend_items):
legend_items_html = ""
for item in legend_items:
if "border" in item:

View File

@@ -36,29 +36,28 @@ from crewai.flow.utils import (
from crewai.utilities.printer import Printer
_printer = Printer()
def method_calls_crew(method: Any) -> bool:
"""
Check if the method contains a call to `.crew()`, `.kickoff()`, or `.kickoff_async()`.
Check if the method contains a call to `.crew()`.
Parameters
----------
method : Any
The method to analyze for crew or agent execution calls.
The method to analyze for crew() calls.
Returns
-------
bool
True if the method calls .crew(), .kickoff(), or .kickoff_async(), False otherwise.
True if the method calls .crew(), False otherwise.
Notes
-----
Uses AST analysis to detect method calls, specifically looking for
attribute access of 'crew', 'kickoff', or 'kickoff_async'.
This includes both traditional Crew execution (.crew()) and Agent/LiteAgent
execution (.kickoff() or .kickoff_async()).
attribute access of 'crew'.
"""
try:
source = inspect.getsource(method)
@@ -69,14 +68,14 @@ def method_calls_crew(method: Any) -> bool:
return False
class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew(), .kickoff(), or .kickoff_async() method calls."""
"""AST visitor to detect .crew() method calls."""
def __init__(self) -> None:
def __init__(self):
self.found = False
def visit_Call(self, node: ast.Call) -> None:
def visit_Call(self, node):
if isinstance(node.func, ast.Attribute):
if node.func.attr in ("crew", "kickoff", "kickoff_async"):
if node.func.attr == "crew":
self.found = True
self.generic_visit(node)
@@ -114,7 +113,7 @@ def add_nodes_to_network(
- Regular methods
"""
def human_friendly_label(method_name: str) -> str:
def human_friendly_label(method_name):
return method_name.replace("_", " ").title()
node_style: (
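
The narrowed detector flags only attribute calls named crew. A self-contained illustration of the same AST walk, independent of crewai:

import ast

class SketchCrewCallVisitor(ast.NodeVisitor):
    """Flags `something.crew(...)` calls; `.kickoff()` no longer matches."""

    def __init__(self) -> None:
        self.found = False

    def visit_Call(self, node: ast.Call) -> None:
        if isinstance(node.func, ast.Attribute) and node.func.attr == "crew":
            self.found = True
        self.generic_visit(node)

def calls_crew(source: str) -> bool:
    visitor = SketchCrewCallVisitor()
    visitor.visit(ast.parse(source))
    return visitor.found

assert calls_crew("self.poem_crew.crew().kickoff()") is True
assert calls_crew("agent.kickoff('query')") is False  # no longer detected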

View File

@@ -175,11 +175,8 @@ class BedrockCompletion(BaseLLM):
guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
**kwargs: Additional parameters (including model_id for cross-region inference)
**kwargs: Additional parameters
"""
# Extract model_id from kwargs if provided (for cross-region inference profiles)
custom_model_id = kwargs.pop("model_id", None)
# Extract provider from kwargs to avoid duplicate argument
kwargs.pop("provider", None)
@@ -233,7 +230,7 @@ class BedrockCompletion(BaseLLM):
self.supports_streaming = True
# Handle inference profiles for newer models
self.model_id = custom_model_id if custom_model_id else model
self.model_id = model
def call(
self,

View File

@@ -1,17 +1,9 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING, Any
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, cast
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.events.lifecycle_decorator import with_lifecycle_events
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.interface import Storage
@@ -19,29 +11,31 @@ from crewai.rag.embeddings.types import ProviderSpec
if TYPE_CHECKING:
from crewai.memory.storage.mem0_storage import Mem0Storage
from crewai.crew import Crew
class ExternalMemory(Memory):
def __init__(self, storage: Storage | None = None, **data: Any):
def __init__(self, storage: Storage | None = None, **data: Any) -> None:
super().__init__(storage=storage, **data)
@staticmethod
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
from crewai.memory.storage.mem0_storage import Mem0Storage
def _configure_mem0(crew: Crew, config: dict[str, Any]) -> Storage:
from crewai.memory.storage.mem0_storage import Mem0Config, Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
return Mem0Storage(
type="external", crew=crew, config=cast(Mem0Config, cast(object, config))
)
@staticmethod
def external_supported_storages() -> dict[str, Any]:
def external_supported_storages() -> dict[
str, Callable[[Crew, dict[str, Any]], Storage]
]:
return {
"mem0": ExternalMemory._configure_mem0,
}
@staticmethod
def create_storage(
crew: Any, embedder_config: dict[str, Any] | ProviderSpec | None
) -> Storage:
def create_storage(crew: Crew, embedder_config: ProviderSpec | None) -> Storage:
if not embedder_config:
raise ValueError("embedder_config is required")
@@ -53,115 +47,59 @@ class ExternalMemory(Memory):
if provider not in supported_storages:
raise ValueError(f"Provider {provider} not supported")
return supported_storages[provider](crew, embedder_config.get("config", {}))
config = embedder_config.get("config", {})
return supported_storages[provider](crew, cast(dict[str, Any], config))
@with_lifecycle_events(
"memory_save",
args_map={"value": "value", "metadata": "metadata"},
context={
"source_type": "external_memory",
"from_agent": lambda self: self.agent,
"from_task": lambda self: self.task,
},
elapsed_name="save_time_ms",
)
def save(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Saves a value into the external storage."""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value,
metadata=metadata,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
item = ExternalMemoryItem(
value=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
super().save(value=item.value, metadata=item.metadata)
start_time = time.time()
try:
item = ExternalMemoryItem(
value=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
super().save(value=item.value, metadata=item.metadata)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=value,
metadata=metadata,
error=str(e),
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
@with_lifecycle_events(
"memory_query",
args_map={
"query": "query",
"limit": "limit",
"score_threshold": "score_threshold",
},
context={
"source_type": "external_memory",
"from_agent": lambda self: self.agent,
"from_task": lambda self: self.task,
},
result_name="results",
elapsed_name="query_time_ms",
)
def search(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = super().search(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="external_memory",
),
)
raise
) -> Any:
return super().search(query=query, limit=limit, score_threshold=score_threshold)
def reset(self) -> None:
self.storage.reset()
def set_crew(self, crew: Any) -> ExternalMemory:
def set_crew(self, crew: Crew) -> ExternalMemory:
super().set_crew(crew)
if not self.storage:
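
For reference, the provider dispatch in create_storage reduces to a small factory lookup. A stand-alone imitation of that flow (fake_mem0_factory is hypothetical, standing in for _configure_mem0, which builds a Mem0Storage):

from typing import Any, Callable

def fake_mem0_factory(crew: object, config: dict[str, Any]) -> str:
    return f"Mem0Storage(user_id={config.get('user_id')!r})"

SUPPORTED: dict[str, Callable[[object, dict[str, Any]], str]] = {
    "mem0": fake_mem0_factory,
}

def create_storage(crew: object, embedder_config: dict[str, Any] | None) -> str:
    # Mirrors the guard order in ExternalMemory.create_storage.
    if not embedder_config:
        raise ValueError("embedder_config is required")
    provider = embedder_config.get("provider")
    if provider not in SUPPORTED:
        raise ValueError(f"Provider {provider} not supported")
    return SUPPORTED[provider](crew, embedder_config.get("config", {}))

print(create_storage(None, {"provider": "mem0", "config": {"user_id": "u-123"}}))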

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from pydantic import BaseModel
@@ -24,9 +24,6 @@ class Memory(BaseModel):
_agent: Agent | None = None
_task: Task | None = None
def __init__(self, storage: Any, **data: Any):
super().__init__(storage=storage, **data)
@property
def task(self) -> Task | None:
"""Get the current task associated with this memory."""
@@ -62,8 +59,11 @@ class Memory(BaseModel):
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
return cast(
list[Any],
self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
),
)
def set_crew(self, crew: Any) -> Memory:

View File

@@ -1,16 +1,83 @@
from collections import defaultdict
from __future__ import annotations
from collections.abc import Iterable
import os
import re
from typing import Any
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
from mem0 import Memory, MemoryClient # type: ignore[import-untyped,import-not-found]
from mem0 import Memory, MemoryClient # type: ignore[import-untyped]
from crewai.memory.storage.interface import Storage
from crewai.rag.chromadb.utils import _sanitize_collection_name
MAX_AGENT_ID_LENGTH_MEM0 = 255
if TYPE_CHECKING:
from crewai.crew import Crew
from crewai.utilities.types import LLMMessage, MessageRole
MAX_AGENT_ID_LENGTH_MEM0: Final[int] = 255
_ASSISTANT_MESSAGE_MARKER: Final[str] = "Final Answer:"
_USER_MESSAGE_PATTERN: Final[re.Pattern[str]] = re.compile(r"User message:\s*(.*)")
class BaseMetadata(TypedDict):
short_term: Literal["short_term"]
long_term: Literal["long_term"]
entities: Literal["entity"]
external: Literal["external"]
BASE_METADATA: Final[BaseMetadata] = {
"short_term": "short_term",
"long_term": "long_term",
"entities": "entity",
"external": "external",
}
MEMORY_TYPE_MAP: Final[dict[str, dict[str, str]]] = {
"short_term": {"type": "short_term"},
"long_term": {"type": "long_term"},
"entities": {"type": "entity"},
"external": {"type": "external"},
}
class BaseParams(TypedDict, total=False):
"""Parameters for Mem0 memory operations."""
metadata: dict[str, Any]
infer: bool
includes: Any
excludes: Any
output_format: str
version: str
run_id: str
user_id: str
agent_id: str
class Mem0Config(TypedDict, total=False):
"""Configuration for Mem0Storage."""
run_id: str
includes: Any
excludes: Any
custom_categories: Any
infer: bool
api_key: str
org_id: str
project_id: str
local_mem0_config: Any
user_id: str
agent_id: str
class Mem0Filter(TypedDict, total=False):
"""Filter dictionary for Mem0 search operations."""
AND: list[dict[str, Any]]
OR: list[dict[str, Any]]
class Mem0Storage(Storage):
@@ -18,33 +85,22 @@ class Mem0Storage(Storage):
Extends Storage to handle embedding and searching across entities using Mem0.
"""
def __init__(self, type, crew=None, config=None):
super().__init__()
self._validate_type(type)
def __init__(
self,
type: Literal["short_term", "long_term", "entities", "external"],
crew: Crew | None = None,
config: Mem0Config | None = None,
) -> None:
self.memory_type = type
self.crew = crew
self.config = config or {}
self._extract_config_values()
self._initialize_memory()
def _validate_type(self, type):
supported_types = {"short_term", "long_term", "entities", "external"}
if type not in supported_types:
raise ValueError(
f"Invalid type '{type}' for Mem0Storage. "
f"Must be one of: {', '.join(supported_types)}"
)
def _extract_config_values(self):
self.mem0_run_id = self.config.get("run_id")
self.includes = self.config.get("includes")
self.excludes = self.config.get("excludes")
self.custom_categories = self.config.get("custom_categories")
self.infer = self.config.get("infer", True)
def _initialize_memory(self):
if config is None:
config = {}
self.config: Mem0Config = config
self.mem0_run_id = config.get("run_id")
self.includes = config.get("includes")
self.excludes = config.get("excludes")
self.custom_categories = config.get("custom_categories")
self.infer = config.get("infer", True)
api_key = self.config.get("api_key") or os.getenv("MEM0_API_KEY")
org_id = self.config.get("org_id")
project_id = self.config.get("project_id")
@@ -65,47 +121,39 @@ class Mem0Storage(Storage):
else Memory()
)
def _create_filter_for_search(self):
"""
def _create_filter_for_search(self) -> Mem0Filter:
"""Create filter dictionary for search operations.
Returns:
dict: A filter dictionary containing AND conditions for querying data.
- Includes user_id and agent_id if both are present.
- Includes user_id if only user_id is present.
- Includes agent_id if only agent_id is present.
- Includes run_id if memory_type is 'short_term' and
mem0_run_id is present.
Filter dictionary containing AND/OR conditions for querying data.
"""
filter = defaultdict(list)
if self.memory_type == "short_term" and self.mem0_run_id:
filter["AND"].append({"run_id": self.mem0_run_id})
else:
user_id = self.config.get("user_id", "")
agent_id = self.config.get("agent_id", "")
return {"AND": [{"run_id": self.mem0_run_id}]}
if user_id and agent_id:
filter["OR"].append({"user_id": user_id})
filter["OR"].append({"agent_id": agent_id})
elif user_id:
filter["AND"].append({"user_id": user_id})
elif agent_id:
filter["AND"].append({"agent_id": agent_id})
return filter
user_id = self.config.get("user_id")
agent_id = self.config.get("agent_id")
if user_id and agent_id:
return {"OR": [{"user_id": user_id}, {"agent_id": agent_id}]}
if user_id:
return {"AND": [{"user_id": user_id}]}
if agent_id:
return {"AND": [{"agent_id": agent_id}]}
return {}
def save(self, value: Any, metadata: dict[str, Any]) -> None:
def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str:
return next(
def _last_content(messages_: Iterable[LLMMessage], role: MessageRole) -> str:
content = next(
(
m.get("content", "")
for m in reversed(list(messages))
for m in reversed(list(messages_))
if m.get("role") == role
),
"",
)
return str(content) if content else ""
conversations = []
messages = metadata.pop("messages", None)
messages: Iterable[LLMMessage] = metadata.pop("messages", [])
if messages:
last_user = _last_content(messages, "user")
last_assistant = _last_content(messages, "assistant")
@@ -120,20 +168,11 @@ class Mem0Storage(Storage):
user_id = self.config.get("user_id", "")
base_metadata = {
"short_term": "short_term",
"long_term": "long_term",
"entities": "entity",
"external": "external",
}
# Shared base params
params: dict[str, Any] = {
"metadata": {"type": base_metadata[self.memory_type], **metadata},
params: BaseParams = {
"metadata": {"type": BASE_METADATA[self.memory_type], **metadata},
"infer": self.infer,
}
# MemoryClient-specific overrides
if isinstance(self.memory, MemoryClient):
params["includes"] = self.includes
params["excludes"] = self.excludes
@@ -154,7 +193,7 @@ class Mem0Storage(Storage):
def search(
self, query: str, limit: int = 5, score_threshold: float = 0.6
) -> list[Any]:
params = {
params: dict[str, Any] = {
"query": query,
"limit": limit,
"version": "v2",
@@ -164,15 +203,8 @@ class Mem0Storage(Storage):
if user_id := self.config.get("user_id", ""):
params["user_id"] = user_id
memory_type_map = {
"short_term": {"type": "short_term"},
"long_term": {"type": "long_term"},
"entities": {"type": "entity"},
"external": {"type": "external"},
}
if self.memory_type in memory_type_map:
params["metadata"] = memory_type_map[self.memory_type]
if self.memory_type in MEMORY_TYPE_MAP:
params["metadata"] = MEMORY_TYPE_MAP[self.memory_type]
if self.memory_type == "short_term":
params["run_id"] = self.mem0_run_id
@@ -195,11 +227,12 @@ class Mem0Storage(Storage):
return [r for r in results["results"]]
def reset(self):
def reset(self) -> None:
if self.memory:
self.memory.reset()
def _sanitize_role(self, role: str) -> str:
@staticmethod
def _sanitize_role(role: str) -> str:
"""
Sanitizes agent roles to ensure valid directory names.
"""
@@ -210,21 +243,20 @@ class Mem0Storage(Storage):
return ""
agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
agents_roles = "".join([self._sanitize_role(agent.role) for agent in agents])
return _sanitize_collection_name(
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
name=agents_roles, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
)
def _get_assistant_message(self, text: str) -> str:
marker = "Final Answer:"
if marker in text:
return text.split(marker, 1)[1].strip()
@staticmethod
def _get_assistant_message(text: str) -> str:
if _ASSISTANT_MESSAGE_MARKER in text:
return text.split(_ASSISTANT_MESSAGE_MARKER, 1)[1].strip()
return text
def _get_user_message(self, text: str) -> str:
pattern = r"User message:\s*(.*)"
match = re.search(pattern, text)
@staticmethod
def _get_user_message(text: str) -> str:
match = _USER_MESSAGE_PATTERN.search(text)
if match:
return match.group(1).strip()
return text
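
The rewritten filter builder is now a straight decision ladder: run_id wins for short-term memory, both identities yield an OR, a single identity yields an AND. A stand-alone rendering of the same logic, minus the class plumbing:

from typing import Any

def build_filter(
    memory_type: str,
    run_id: str | None,
    user_id: str | None,
    agent_id: str | None,
) -> dict[str, list[dict[str, Any]]]:
    # Short-term memory scoped to a run wins outright.
    if memory_type == "short_term" and run_id:
        return {"AND": [{"run_id": run_id}]}
    # Both identities present: match either; one present: match it.
    if user_id and agent_id:
        return {"OR": [{"user_id": user_id}, {"agent_id": agent_id}]}
    if user_id:
        return {"AND": [{"user_id": user_id}]}
    if agent_id:
        return {"AND": [{"agent_id": agent_id}]}
    return {}

assert build_filter("short_term", "r1", "u1", None) == {"AND": [{"run_id": "r1"}]}
assert build_filter("external", None, "u1", "a1") == {
    "OR": [{"user_id": "u1"}, {"agent_id": "a1"}]
}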

View File

@@ -29,8 +29,8 @@ def create_llm(
try:
return LLM(model=llm_value)
except Exception as e:
logger.error(f"Error instantiating LLM from string: {e}")
raise e
logger.debug(f"Failed to instantiate LLM with model='{llm_value}': {e}")
return None
if llm_value is None:
return _llm_via_environment_or_fallback()
@@ -62,8 +62,8 @@ def create_llm(
)
except Exception as e:
logger.error(f"Error instantiating LLM from unknown object type: {e}")
raise e
logger.debug(f"Error instantiating LLM from unknown object type: {e}")
return None
UNACCEPTED_ATTRIBUTES: Final[list[str]] = [
@@ -176,10 +176,10 @@ def _llm_via_environment_or_fallback() -> LLM | None:
try:
return LLM(**llm_params)
except Exception as e:
logger.error(
logger.debug(
f"Error instantiating LLM from environment/fallback: {type(e).__name__}: {e}"
)
raise e
return None
def _normalize_key_name(key_name: str) -> str:
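
Because these failure paths now log at debug level and return None instead of re-raising, callers need an explicit guard. A hedged usage sketch (the model string and error message are illustrative):

from crewai.utilities.llm_utils import create_llm

llm = create_llm(llm_value="some-model-string")
if llm is None:
    # create_llm no longer raises on instantiation failure; fail loudly here.
    raise RuntimeError("LLM could not be instantiated; check model name and credentials")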

View File

@@ -3,6 +3,9 @@
from typing import Any, Literal, TypedDict
MessageRole = Literal["user", "assistant", "system"]
class LLMMessage(TypedDict):
"""Type for formatted LLM messages.
@@ -11,5 +14,5 @@ class LLMMessage(TypedDict):
instead of str | list[dict[str, str]]
"""
role: Literal["user", "assistant", "system"]
role: MessageRole
content: str | list[dict[str, Any]]
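
With role narrowed to the MessageRole alias, a type checker rejects anything outside user/assistant/system. A quick sketch:

from crewai.utilities.types import LLMMessage, MessageRole

role: MessageRole = "assistant"
message: LLMMessage = {"role": role, "content": "Final Answer: 42"}
# role = "tool"  # would be rejected by a type checker: not a MessageRole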

View File

@@ -6,7 +6,6 @@ from unittest import mock
from unittest.mock import MagicMock, patch
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
from crewai.cli.constants import DEFAULT_LLM_MODEL
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.knowledge.knowledge import Knowledge
@@ -136,7 +135,7 @@ def test_agent_with_missing_response_template():
def test_agent_default_values():
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
assert agent.llm.model == DEFAULT_LLM_MODEL
assert agent.llm.model == "gpt-4o-mini"
assert agent.allow_delegation is False
@@ -226,7 +225,7 @@ def test_logging_tool_usage():
verbose=True,
)
assert agent.llm.model == DEFAULT_LLM_MODEL
assert agent.llm.model == "gpt-4o-mini"
assert agent.tools_handler.last_used_tool is None
task = Task(
description="What is 3 times 4?",

View File

@@ -18,7 +18,6 @@ def mock_aws_credentials():
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
"AWS_DEFAULT_REGION": "us-east-1"
}):
import crewai.llms.providers.bedrock.completion
# Mock boto3 Session to prevent actual AWS connections
with patch('crewai.llms.providers.bedrock.completion.Session') as mock_session_class:
# Create mock session instance
@@ -737,76 +736,3 @@ def test_bedrock_client_error_handling():
with pytest.raises(RuntimeError) as exc_info:
llm.call("Hello")
assert "throttled" in str(exc_info.value).lower()
def test_bedrock_cross_region_inference_profile():
"""
Test that Bedrock supports cross-region inference profiles with model_id parameter.
This tests the fix for issue #3791 where cross-region inference profiles
(which require using ARN as model_id) were not working in version 1.20.0.
When using cross-region inference profiles, users need to:
1. Set model to the base model name (e.g., "bedrock/anthropic.claude-sonnet-4-20250514-v1:0")
2. Set model_id to the inference profile ARN
The BedrockCompletion should use the model_id parameter when provided,
not the model parameter, for the actual API call.
"""
# Test with cross-region inference profile ARN
inference_profile_arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0"
llm = LLM(
model="bedrock/anthropic.claude-sonnet-4-20250514-v1:0",
model_id=inference_profile_arn,
temperature=0.3,
max_tokens=4000,
)
from crewai.llms.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion)
assert llm.model_id == inference_profile_arn
assert llm.model == "anthropic.claude-sonnet-4-20250514-v1:0"
# Verify that the client.converse call would use the correct model_id
with patch.object(llm.client, 'converse') as mock_converse:
mock_converse.return_value = {
'output': {
'message': {
'role': 'assistant',
'content': [{'text': 'Test response'}]
}
},
'usage': {
'inputTokens': 10,
'outputTokens': 5,
'totalTokens': 15
}
}
llm.call("Test message")
# Verify the converse call was made with the inference profile ARN
mock_converse.assert_called_once()
call_kwargs = mock_converse.call_args[1]
assert call_kwargs['modelId'] == inference_profile_arn
def test_bedrock_model_id_parameter_takes_precedence():
"""
Test that when both model and model_id are provided, model_id takes precedence
for the actual API call, while model is used for internal identification.
"""
custom_model_id = "custom-model-identifier"
llm = LLM(
model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
model_id=custom_model_id,
)
from crewai.llms.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion)
assert llm.model_id == custom_model_id
assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0"

View File

@@ -850,31 +850,6 @@ def test_flow_plotting():
assert isinstance(received_events[0].timestamp, datetime)
def test_method_calls_crew_detection():
"""Test that method_calls_crew() detects .crew(), .kickoff(), and .kickoff_async() calls."""
from crewai.flow.visualization_utils import method_calls_crew
from crewai import Agent
# Test with a real Flow that uses agent.kickoff()
class FlowWithAgentKickoff(Flow):
@start()
def run_agent(self):
agent = Agent(role="test", goal="test", backstory="test")
return agent.kickoff("query")
flow = FlowWithAgentKickoff()
assert method_calls_crew(flow.run_agent) is True
# Test with a Flow that has no crew/agent calls
class FlowWithoutCrewCalls(Flow):
@start()
def simple_method(self):
return "Just a regular method"
flow2 = FlowWithoutCrewCalls()
assert method_calls_crew(flow2.simple_method) is False
def test_multiple_routers_from_same_trigger():
"""Test that multiple routers triggered by the same method all activate their listeners."""
execution_order = []

View File

@@ -1,79 +1,77 @@
import os
from typing import Any
from unittest.mock import patch
from crewai.cli.constants import DEFAULT_LLM_MODEL
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.llm_utils import create_llm
import pytest
def test_create_llm_with_llm_instance() -> None:
try:
from litellm.exceptions import BadRequestError
except ImportError:
BadRequestError = Exception
def test_create_llm_with_llm_instance():
existing_llm = LLM(model="gpt-4o")
llm = create_llm(llm_value=existing_llm)
assert llm is existing_llm
def test_create_llm_with_valid_model_string():
llm = create_llm(llm_value="gpt-4o")
assert isinstance(llm, BaseLLM)
assert llm.model == "gpt-4o"
def test_create_llm_with_invalid_model_string():
# For invalid model strings, create_llm succeeds but call() fails with API error
llm = create_llm(llm_value="invalid-model")
assert llm is not None
assert isinstance(llm, BaseLLM)
# The error should occur when making the actual API call
# We expect some kind of API error (NotFoundError, etc.)
with pytest.raises(Exception): # noqa: B017
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
def test_create_llm_with_unknown_object_missing_attributes():
class UnknownObject:
pass
unknown_obj = UnknownObject()
llm = create_llm(llm_value=unknown_obj)
# Should succeed because str(unknown_obj) provides a model name
assert llm is not None
assert isinstance(llm, BaseLLM)
def test_create_llm_with_none_uses_default_model():
with patch.dict(os.environ, {"OPENAI_API_KEY": "fake-key"}, clear=True):
existing_llm = LLM(model="gpt-4o")
llm = create_llm(llm_value=existing_llm)
assert llm is existing_llm
def test_create_llm_with_valid_model_string() -> None:
with patch.dict(os.environ, {"OPENAI_API_KEY": "fake-key"}, clear=True):
llm = create_llm(llm_value="gpt-4o")
assert isinstance(llm, BaseLLM)
assert llm.model == "gpt-4o"
def test_create_llm_with_invalid_model_string() -> None:
with patch.dict(os.environ, {"OPENAI_API_KEY": "fake-key"}, clear=True):
# For invalid model strings, create_llm succeeds but call() fails with API error
llm = create_llm(llm_value="invalid-model")
assert llm is not None
assert isinstance(llm, BaseLLM)
# The error should occur when making the actual API call
# We expect some kind of API error (NotFoundError, etc.)
with pytest.raises(Exception): # noqa: B017
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
def test_create_llm_with_unknown_object_missing_attributes() -> None:
with patch.dict(os.environ, {"OPENAI_API_KEY": "fake-key"}, clear=True):
class UnknownObject:
pass
unknown_obj = UnknownObject()
llm = create_llm(llm_value=unknown_obj)
# Should succeed because str(unknown_obj) provides a model name
assert llm is not None
assert isinstance(llm, BaseLLM)
def test_create_llm_with_none_uses_default_model() -> None:
with patch.dict(os.environ, {"OPENAI_API_KEY": "fake-key"}, clear=True):
with patch("crewai.utilities.llm_utils.DEFAULT_LLM_MODEL", DEFAULT_LLM_MODEL):
with patch("crewai.utilities.llm_utils.DEFAULT_LLM_MODEL", "gpt-4o-mini"):
llm = create_llm(llm_value=None)
assert isinstance(llm, BaseLLM)
assert llm.model == DEFAULT_LLM_MODEL
assert llm.model == "gpt-4o-mini"
def test_create_llm_with_unknown_object() -> None:
with patch.dict(os.environ, {"OPENAI_API_KEY": "fake-key"}, clear=True):
class UnknownObject:
model_name = "gpt-4o"
temperature = 0.7
max_tokens = 1500
def test_create_llm_with_unknown_object():
class UnknownObject:
model_name = "gpt-4o"
temperature = 0.7
max_tokens = 1500
unknown_obj = UnknownObject()
llm = create_llm(llm_value=unknown_obj)
assert isinstance(llm, BaseLLM)
assert llm.model == "gpt-4o"
assert llm.temperature == 0.7
if hasattr(llm, 'max_tokens'):
assert llm.max_tokens == 1500
unknown_obj = UnknownObject()
llm = create_llm(llm_value=unknown_obj)
assert isinstance(llm, BaseLLM)
assert llm.model == "gpt-4o"
assert llm.temperature == 0.7
assert llm.max_tokens == 1500
def test_create_llm_from_env_with_unaccepted_attributes() -> None:
def test_create_llm_from_env_with_unaccepted_attributes():
with patch.dict(
os.environ,
{
@@ -92,47 +90,25 @@ def test_create_llm_from_env_with_unaccepted_attributes() -> None:
assert not hasattr(llm, "AWS_REGION_NAME")
def test_create_llm_with_partial_attributes() -> None:
with patch.dict(os.environ, {"OPENAI_API_KEY": "fake-key"}, clear=True):
class PartialAttributes:
model_name = "gpt-4o"
# temperature is missing
def test_create_llm_with_partial_attributes():
class PartialAttributes:
model_name = "gpt-4o"
# temperature is missing
obj = PartialAttributes()
llm = create_llm(llm_value=obj)
assert isinstance(llm, BaseLLM)
assert llm.model == "gpt-4o"
assert llm.temperature is None # Should handle missing attributes gracefully
obj = PartialAttributes()
llm = create_llm(llm_value=obj)
assert isinstance(llm, BaseLLM)
assert llm.model == "gpt-4o"
assert llm.temperature is None # Should handle missing attributes gracefully
def test_create_llm_with_invalid_type() -> None:
with patch.dict(os.environ, {"OPENAI_API_KEY": "fake-key"}, clear=True):
# For integers, create_llm succeeds because str(42) becomes "42"
llm = create_llm(llm_value=42)
assert llm is not None
assert isinstance(llm, BaseLLM)
assert llm.model == "42"
def test_create_llm_with_invalid_type():
# For integers, create_llm succeeds because str(42) becomes "42"
llm = create_llm(llm_value=42)
assert llm is not None
assert isinstance(llm, BaseLLM)
assert llm.model == "42"
# The error should occur when making the actual API call
with pytest.raises(Exception): # noqa: B017
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
def test_create_llm_openai_missing_api_key() -> None:
"""Test that create_llm raises error when OpenAI API key is missing"""
with patch.dict(os.environ, {}, clear=True):
with pytest.raises((ValueError, ImportError)) as exc_info:
create_llm(llm_value="gpt-4o")
error_message = str(exc_info.value).lower()
assert "openai_api_key" in error_message or "api_key" in error_message
def test_create_llm_anthropic_missing_dependency() -> None:
"""Test that create_llm raises error when Anthropic dependency is missing"""
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "fake-key"}, clear=True):
with patch("crewai.llm.LLM.__new__", side_effect=ImportError('Anthropic native provider not available, to install: uv add "crewai[anthropic]"')):
with pytest.raises(ImportError) as exc_info:
create_llm(llm_value="anthropic/claude-3-sonnet")
assert "Anthropic native provider not available, to install: uv add \"crewai[anthropic]\"" in str(exc_info.value)
# The error should occur when making the actual API call
with pytest.raises(Exception): # noqa: B017
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])

View File

@@ -1,3 +1,3 @@
"""CrewAI development tools."""
__version__ = "1.2.0"
__version__ = "1.1.0"

uv.lock generated (8014 changes)

File diff suppressed because it is too large