mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-14 23:12:37 +00:00
Compare commits
6 Commits
1.10.0
...
joaomdmour
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
83afb17cf9 | ||
|
|
01df1ef3cf | ||
|
|
358fd92e6b | ||
|
|
c4d4ea6c71 | ||
|
|
24c68d4053 | ||
|
|
320326e3e5 |
@@ -106,15 +106,6 @@ There are different places in CrewAI code where you can specify the model to use
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
<Info>
|
||||
CrewAI provides native SDK integrations for OpenAI, Anthropic, Google (Gemini API), Azure, and AWS Bedrock — no extra install needed beyond the provider-specific extras (e.g. `uv add "crewai[openai]"`).
|
||||
|
||||
All other providers are powered by **LiteLLM**. If you plan to use any of them, add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Info>
|
||||
|
||||
## Provider Configuration Examples
|
||||
|
||||
CrewAI supports a multitude of LLM providers, each offering unique features, authentication methods, and model capabilities.
|
||||
@@ -284,11 +275,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
| `meta_llama/Llama-4-Maverick-17B-128E-Instruct-FP8` | 128k | 4028 | Text, Image | Text |
|
||||
| `meta_llama/Llama-3.3-70B-Instruct` | 128k | 4028 | Text | Text |
|
||||
| `meta_llama/Llama-3.3-8B-Instruct` | 128k | 4028 | Text | Text |
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Anthropic">
|
||||
@@ -585,11 +571,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
| gemini-1.5-flash | 1M tokens | Balanced multimodal model, good for most tasks |
|
||||
| gemini-1.5-flash-8B | 1M tokens | Fastest, most cost-efficient, good for high-frequency tasks |
|
||||
| gemini-1.5-pro | 2M tokens | Best performing, wide variety of reasoning tasks including logical reasoning, coding, and creative collaboration |
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Azure">
|
||||
@@ -785,11 +766,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
model="sagemaker/<my-endpoint>"
|
||||
)
|
||||
```
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Mistral">
|
||||
@@ -805,11 +781,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
temperature=0.7
|
||||
)
|
||||
```
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Nvidia NIM">
|
||||
@@ -896,11 +867,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
| rakuten/rakutenai-7b-instruct | 1,024 tokens | Advanced state-of-the-art LLM with language understanding, superior reasoning, and text generation. |
|
||||
| rakuten/rakutenai-7b-chat | 1,024 tokens | Advanced state-of-the-art LLM with language understanding, superior reasoning, and text generation. |
|
||||
| baichuan-inc/baichuan2-13b-chat | 4,096 tokens | Support Chinese and English chat, coding, math, instruction following, solving quizzes |
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Local NVIDIA NIM Deployed using WSL2">
|
||||
@@ -941,11 +907,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
|
||||
# ...
|
||||
```
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Groq">
|
||||
@@ -967,11 +928,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
| Llama 3.1 70B/8B | 131,072 tokens | High-performance, large context tasks |
|
||||
| Llama 3.2 Series | 8,192 tokens | General-purpose tasks |
|
||||
| Mixtral 8x7B | 32,768 tokens | Balanced performance and context |
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="IBM watsonx.ai">
|
||||
@@ -994,11 +950,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
base_url="https://api.watsonx.ai/v1"
|
||||
)
|
||||
```
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Ollama (Local LLMs)">
|
||||
@@ -1012,11 +963,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
base_url="http://localhost:11434"
|
||||
)
|
||||
```
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Fireworks AI">
|
||||
@@ -1032,11 +978,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
temperature=0.7
|
||||
)
|
||||
```
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Perplexity AI">
|
||||
@@ -1052,11 +993,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
base_url="https://api.perplexity.ai/"
|
||||
)
|
||||
```
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Hugging Face">
|
||||
@@ -1071,11 +1007,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||
)
|
||||
```
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="SambaNova">
|
||||
@@ -1099,11 +1030,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
| Llama 3.2 Series | 8,192 tokens | General-purpose, multimodal tasks |
|
||||
| Llama 3.3 70B | Up to 131,072 tokens | High-performance and output quality |
|
||||
| Qwen2 familly | 8,192 tokens | High-performance and output quality |
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Cerebras">
|
||||
@@ -1129,11 +1055,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
- Good balance of speed and quality
|
||||
- Support for long context windows
|
||||
</Info>
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Open Router">
|
||||
@@ -1156,11 +1077,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
- openrouter/deepseek/deepseek-r1
|
||||
- openrouter/deepseek/deepseek-chat
|
||||
</Info>
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Nebius AI Studio">
|
||||
@@ -1183,11 +1099,6 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
- Competitive pricing
|
||||
- Good balance of speed and quality
|
||||
</Info>
|
||||
|
||||
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
</AccordionGroup>
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ mode: "wide"
|
||||
|
||||
## Connect CrewAI to LLMs
|
||||
|
||||
CrewAI connects to LLMs through native SDK integrations for the most popular providers (OpenAI, Anthropic, Google Gemini, Azure, and AWS Bedrock), and uses LiteLLM as a flexible fallback for all other providers.
|
||||
CrewAI uses LiteLLM to connect to a wide variety of Language Models (LLMs). This integration provides extensive versatility, allowing you to use models from numerous providers with a simple, unified interface.
|
||||
|
||||
<Note>
|
||||
By default, CrewAI uses the `gpt-4o-mini` model. This is determined by the `OPENAI_MODEL_NAME` environment variable, which defaults to "gpt-4o-mini" if not set.
|
||||
@@ -41,14 +41,6 @@ LiteLLM supports a wide range of providers, including but not limited to:
|
||||
|
||||
For a complete and up-to-date list of supported providers, please refer to the [LiteLLM Providers documentation](https://docs.litellm.ai/docs/providers).
|
||||
|
||||
<Info>
|
||||
To use any provider not covered by a native integration, add LiteLLM as a dependency to your project:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
Native providers (OpenAI, Anthropic, Google Gemini, Azure, AWS Bedrock) use their own SDK extras — see the [Provider Configuration Examples](/en/concepts/llms#provider-configuration-examples).
|
||||
</Info>
|
||||
|
||||
## Changing the LLM
|
||||
|
||||
To use a different LLM with your CrewAI agents, you have several options:
|
||||
|
||||
@@ -35,7 +35,7 @@ Visit [app.crewai.com](https://app.crewai.com) and create your free account. Thi
|
||||
If you haven't already, install CrewAI with the CLI tools:
|
||||
|
||||
```bash
|
||||
uv add 'crewai[tools]'
|
||||
uv add crewai[tools]
|
||||
```
|
||||
|
||||
Then authenticate your CLI with your CrewAI AMP account:
|
||||
|
||||
@@ -105,15 +105,6 @@ CrewAI 코드 내에는 사용할 모델을 지정할 수 있는 여러 위치
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
<Info>
|
||||
CrewAI는 OpenAI, Anthropic, Google (Gemini API), Azure, AWS Bedrock에 대해 네이티브 SDK 통합을 제공합니다 — 제공자별 extras(예: `uv add "crewai[openai]"`) 외에 추가 설치가 필요하지 않습니다.
|
||||
|
||||
그 외 모든 제공자는 **LiteLLM**을 통해 지원됩니다. 이를 사용하려면 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Info>
|
||||
|
||||
## 공급자 구성 예시
|
||||
|
||||
CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양한 LLM 공급자를 지원합니다.
|
||||
@@ -223,11 +214,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
| `meta_llama/Llama-4-Maverick-17B-128E-Instruct-FP8` | 128k | 4028 | 텍스트, 이미지 | 텍스트 |
|
||||
| `meta_llama/Llama-3.3-70B-Instruct` | 128k | 4028 | 텍스트 | 텍스트 |
|
||||
| `meta_llama/Llama-3.3-8B-Instruct` | 128k | 4028 | 텍스트 | 텍스트 |
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Anthropic">
|
||||
@@ -368,11 +354,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
| gemini-1.5-flash | 1M 토큰 | 밸런스 잡힌 멀티모달 모델, 대부분의 작업에 적합 |
|
||||
| gemini-1.5-flash-8B | 1M 토큰 | 가장 빠르고, 비용 효율적, 고빈도 작업에 적합 |
|
||||
| gemini-1.5-pro | 2M 토큰 | 최고의 성능, 논리적 추론, 코딩, 창의적 협업 등 다양한 추론 작업에 적합 |
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Azure">
|
||||
@@ -458,11 +439,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
model="sagemaker/<my-endpoint>"
|
||||
)
|
||||
```
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Mistral">
|
||||
@@ -478,11 +454,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
temperature=0.7
|
||||
)
|
||||
```
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Nvidia NIM">
|
||||
@@ -569,11 +540,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
| rakuten/rakutenai-7b-instruct | 1,024 토큰 | 언어 이해, 추론, 텍스트 생성이 탁월한 최첨단 LLM |
|
||||
| rakuten/rakutenai-7b-chat | 1,024 토큰 | 언어 이해, 추론, 텍스트 생성이 탁월한 최첨단 LLM |
|
||||
| baichuan-inc/baichuan2-13b-chat | 4,096 토큰 | 중국어 및 영어 대화, 코딩, 수학, 지시 따르기, 퀴즈 풀이 지원 |
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Local NVIDIA NIM Deployed using WSL2">
|
||||
@@ -614,11 +580,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
|
||||
# ...
|
||||
```
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Groq">
|
||||
@@ -640,11 +601,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
| Llama 3.1 70B/8B| 131,072 토큰 | 고성능, 대용량 문맥 작업 |
|
||||
| Llama 3.2 Series| 8,192 토큰 | 범용 작업 |
|
||||
| Mixtral 8x7B | 32,768 토큰 | 성능과 문맥의 균형 |
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="IBM watsonx.ai">
|
||||
@@ -667,11 +623,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
base_url="https://api.watsonx.ai/v1"
|
||||
)
|
||||
```
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Ollama (Local LLMs)">
|
||||
@@ -685,11 +636,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
base_url="http://localhost:11434"
|
||||
)
|
||||
```
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Fireworks AI">
|
||||
@@ -705,11 +651,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
temperature=0.7
|
||||
)
|
||||
```
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Perplexity AI">
|
||||
@@ -725,11 +666,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
base_url="https://api.perplexity.ai/"
|
||||
)
|
||||
```
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Hugging Face">
|
||||
@@ -744,11 +680,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||
)
|
||||
```
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="SambaNova">
|
||||
@@ -772,11 +703,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
| Llama 3.2 Series| 8,192 토큰 | 범용, 멀티모달 작업 |
|
||||
| Llama 3.3 70B | 최대 131,072 토큰 | 고성능, 높은 출력 품질 |
|
||||
| Qwen2 familly | 8,192 토큰 | 고성능, 높은 출력 품질 |
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Cerebras">
|
||||
@@ -802,11 +728,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
- 속도와 품질의 우수한 밸런스
|
||||
- 긴 컨텍스트 윈도우 지원
|
||||
</Info>
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Open Router">
|
||||
@@ -829,11 +750,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
- openrouter/deepseek/deepseek-r1
|
||||
- openrouter/deepseek/deepseek-chat
|
||||
</Info>
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Nebius AI Studio">
|
||||
@@ -856,11 +772,6 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
|
||||
- 경쟁력 있는 가격
|
||||
- 속도와 품질의 우수한 밸런스
|
||||
</Info>
|
||||
|
||||
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
</AccordionGroup>
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ mode: "wide"
|
||||
|
||||
## CrewAI를 LLM에 연결하기
|
||||
|
||||
CrewAI는 가장 인기 있는 제공자(OpenAI, Anthropic, Google Gemini, Azure, AWS Bedrock)에 대해 네이티브 SDK 통합을 통해 LLM에 연결하며, 그 외 모든 제공자에 대해서는 LiteLLM을 유연한 폴백으로 사용합니다.
|
||||
CrewAI는 LiteLLM을 사용하여 다양한 언어 모델(LLM)에 연결합니다. 이 통합은 높은 다양성을 제공하여, 여러 공급자의 모델을 간단하고 통합된 인터페이스로 사용할 수 있게 해줍니다.
|
||||
|
||||
<Note>
|
||||
기본적으로 CrewAI는 `gpt-4o-mini` 모델을 사용합니다. 이는 `OPENAI_MODEL_NAME` 환경 변수에 의해 결정되며, 설정되지 않은 경우 기본값은 "gpt-4o-mini"입니다.
|
||||
@@ -41,14 +41,6 @@ LiteLLM은 다음을 포함하되 이에 국한되지 않는 다양한 프로바
|
||||
|
||||
지원되는 프로바이더의 전체 및 최신 목록은 [LiteLLM 프로바이더 문서](https://docs.litellm.ai/docs/providers)를 참조하세요.
|
||||
|
||||
<Info>
|
||||
네이티브 통합에서 지원하지 않는 제공자를 사용하려면 LiteLLM을 프로젝트에 의존성으로 추가하세요:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
네이티브 제공자(OpenAI, Anthropic, Google Gemini, Azure, AWS Bedrock)는 자체 SDK extras를 사용합니다 — [공급자 구성 예시](/ko/concepts/llms#공급자-구성-예시)를 참조하세요.
|
||||
</Info>
|
||||
|
||||
## LLM 변경하기
|
||||
|
||||
CrewAI agent에서 다른 LLM을 사용하려면 여러 가지 방법이 있습니다:
|
||||
|
||||
@@ -35,7 +35,7 @@ crewai login
|
||||
아직 설치하지 않았다면 CLI 도구와 함께 CrewAI를 설치하세요:
|
||||
|
||||
```bash
|
||||
uv add 'crewai[tools]'
|
||||
uv add crewai[tools]
|
||||
```
|
||||
|
||||
그런 다음 CrewAI AMP 계정으로 CLI를 인증하세요:
|
||||
|
||||
@@ -105,15 +105,6 @@ Existem diferentes locais no código do CrewAI onde você pode especificar o mod
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
<Info>
|
||||
O CrewAI oferece integrações nativas via SDK para OpenAI, Anthropic, Google (Gemini API), Azure e AWS Bedrock — sem necessidade de instalação extra além dos extras específicos do provedor (ex.: `uv add "crewai[openai]"`).
|
||||
|
||||
Todos os outros provedores são alimentados pelo **LiteLLM**. Se você planeja usar algum deles, adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Info>
|
||||
|
||||
## Exemplos de Configuração de Provedores
|
||||
|
||||
O CrewAI suporta uma grande variedade de provedores de LLM, cada um com recursos, métodos de autenticação e capacidades de modelo únicos.
|
||||
@@ -223,11 +214,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
| `meta_llama/Llama-4-Maverick-17B-128E-Instruct-FP8` | 128k | 4028 | Texto, Imagem | Texto |
|
||||
| `meta_llama/Llama-3.3-70B-Instruct` | 128k | 4028 | Texto | Texto |
|
||||
| `meta_llama/Llama-3.3-8B-Instruct` | 128k | 4028 | Texto | Texto |
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Anthropic">
|
||||
@@ -368,11 +354,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
| gemini-1.5-flash | 1M tokens | Modelo multimodal equilibrado, bom para maioria das tarefas |
|
||||
| gemini-1.5-flash-8B | 1M tokens | Mais rápido, mais eficiente em custo, adequado para tarefas de alta frequência |
|
||||
| gemini-1.5-pro | 2M tokens | Melhor desempenho para uma ampla variedade de tarefas de raciocínio, incluindo lógica, codificação e colaboração criativa |
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Azure">
|
||||
@@ -457,11 +438,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
model="sagemaker/<my-endpoint>"
|
||||
)
|
||||
```
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Mistral">
|
||||
@@ -477,11 +453,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
temperature=0.7
|
||||
)
|
||||
```
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Nvidia NIM">
|
||||
@@ -568,11 +539,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
| rakuten/rakutenai-7b-instruct | 1.024 tokens | LLM topo de linha, compreensão, raciocínio e geração textual.|
|
||||
| rakuten/rakutenai-7b-chat | 1.024 tokens | LLM topo de linha, compreensão, raciocínio e geração textual.|
|
||||
| baichuan-inc/baichuan2-13b-chat | 4.096 tokens | Suporte a chat em chinês/inglês, programação, matemática, seguir instruções, resolver quizzes.|
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Local NVIDIA NIM Deployed using WSL2">
|
||||
@@ -613,11 +579,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
|
||||
# ...
|
||||
```
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Groq">
|
||||
@@ -639,11 +600,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
| Llama 3.1 70B/8B | 131.072 tokens | Alta performance e tarefas de contexto grande|
|
||||
| Llama 3.2 Série | 8.192 tokens | Tarefas gerais |
|
||||
| Mixtral 8x7B | 32.768 tokens | Equilíbrio entre performance e contexto |
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="IBM watsonx.ai">
|
||||
@@ -666,11 +622,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
base_url="https://api.watsonx.ai/v1"
|
||||
)
|
||||
```
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Ollama (LLMs Locais)">
|
||||
@@ -684,11 +635,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
base_url="http://localhost:11434"
|
||||
)
|
||||
```
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Fireworks AI">
|
||||
@@ -704,11 +650,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
temperature=0.7
|
||||
)
|
||||
```
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Perplexity AI">
|
||||
@@ -724,11 +665,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
base_url="https://api.perplexity.ai/"
|
||||
)
|
||||
```
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Hugging Face">
|
||||
@@ -743,11 +679,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||
)
|
||||
```
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="SambaNova">
|
||||
@@ -771,11 +702,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
| Llama 3.2 Série | 8.192 tokens | Tarefas gerais e multimodais |
|
||||
| Llama 3.3 70B | Até 131.072 tokens | Desempenho e qualidade de saída elevada |
|
||||
| Família Qwen2 | 8.192 tokens | Desempenho e qualidade de saída elevada |
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Cerebras">
|
||||
@@ -801,11 +727,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
- Equilíbrio entre velocidade e qualidade
|
||||
- Suporte a longas janelas de contexto
|
||||
</Info>
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Open Router">
|
||||
@@ -828,11 +749,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
|
||||
- openrouter/deepseek/deepseek-r1
|
||||
- openrouter/deepseek/deepseek-chat
|
||||
</Info>
|
||||
|
||||
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
</Accordion>
|
||||
</AccordionGroup>
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ mode: "wide"
|
||||
|
||||
## Conecte o CrewAI a LLMs
|
||||
|
||||
O CrewAI conecta-se a LLMs por meio de integrações nativas via SDK para os provedores mais populares (OpenAI, Anthropic, Google Gemini, Azure e AWS Bedrock), e usa o LiteLLM como alternativa flexível para todos os demais provedores.
|
||||
O CrewAI utiliza o LiteLLM para conectar-se a uma grande variedade de Modelos de Linguagem (LLMs). Essa integração proporciona grande versatilidade, permitindo que você utilize modelos de inúmeros provedores por meio de uma interface simples e unificada.
|
||||
|
||||
<Note>
|
||||
Por padrão, o CrewAI usa o modelo `gpt-4o-mini`. Isso é determinado pela variável de ambiente `OPENAI_MODEL_NAME`, que tem como padrão "gpt-4o-mini" se não for definida.
|
||||
@@ -40,14 +40,6 @@ O LiteLLM oferece suporte a uma ampla gama de provedores, incluindo, mas não se
|
||||
|
||||
Para uma lista completa e sempre atualizada dos provedores suportados, consulte a [documentação de Provedores do LiteLLM](https://docs.litellm.ai/docs/providers).
|
||||
|
||||
<Info>
|
||||
Para usar qualquer provedor não coberto por uma integração nativa, adicione o LiteLLM como dependência ao seu projeto:
|
||||
```bash
|
||||
uv add 'crewai[litellm]'
|
||||
```
|
||||
Provedores nativos (OpenAI, Anthropic, Google Gemini, Azure, AWS Bedrock) usam seus próprios extras de SDK — consulte os [Exemplos de Configuração de Provedores](/pt-BR/concepts/llms#exemplos-de-configuração-de-provedores).
|
||||
</Info>
|
||||
|
||||
## Alterando a LLM
|
||||
|
||||
Para utilizar uma LLM diferente com seus agentes CrewAI, você tem várias opções:
|
||||
|
||||
@@ -152,4 +152,4 @@ __all__ = [
|
||||
"wrap_file_source",
|
||||
]
|
||||
|
||||
__version__ = "1.10.0"
|
||||
__version__ = "1.9.3"
|
||||
|
||||
@@ -11,7 +11,7 @@ dependencies = [
|
||||
"pytube~=15.0.0",
|
||||
"requests~=2.32.5",
|
||||
"docker~=7.1.0",
|
||||
"crewai==1.10.0",
|
||||
"crewai==1.9.3",
|
||||
"tiktoken~=0.8.0",
|
||||
"beautifulsoup4~=4.13.4",
|
||||
"python-docx~=1.2.0",
|
||||
|
||||
@@ -291,4 +291,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.10.0"
|
||||
__version__ = "1.9.3"
|
||||
|
||||
@@ -20117,6 +20117,18 @@
|
||||
"humanized_name": "Web Automation Tool",
|
||||
"init_params_schema": {
|
||||
"$defs": {
|
||||
"AvailableModel": {
|
||||
"enum": [
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"claude-3-5-sonnet-latest",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"computer-use-preview",
|
||||
"gemini-2.0-flash"
|
||||
],
|
||||
"title": "AvailableModel",
|
||||
"type": "string"
|
||||
},
|
||||
"EnvVar": {
|
||||
"properties": {
|
||||
"default": {
|
||||
@@ -20194,6 +20206,17 @@
|
||||
"default": null,
|
||||
"title": "Model Api Key"
|
||||
},
|
||||
"model_name": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/$defs/AvailableModel"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"default": "claude-3-7-sonnet-latest"
|
||||
},
|
||||
"project_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
|
||||
@@ -53,7 +53,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.10.0",
|
||||
"crewai-tools==1.9.3",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
|
||||
@@ -10,7 +10,6 @@ from crewai.flow.flow import Flow
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.process import Process
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
@@ -41,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.10.0"
|
||||
__version__ = "1.9.3"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
@@ -72,6 +71,25 @@ def _track_install_async() -> None:
|
||||
|
||||
|
||||
_track_install_async()
|
||||
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"Memory": ("crewai.memory.unified_memory", "Memory"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Lazily import heavy modules (e.g. Memory → lancedb) on first access."""
|
||||
if name in _LAZY_IMPORTS:
|
||||
module_path, attr = _LAZY_IMPORTS[name]
|
||||
import importlib
|
||||
|
||||
mod = importlib.import_module(module_path)
|
||||
val = getattr(mod, attr)
|
||||
globals()[name] = val
|
||||
return val
|
||||
raise AttributeError(f"module 'crewai' has no attribute {name!r}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LLM",
|
||||
"Agent",
|
||||
|
||||
@@ -8,9 +8,11 @@ import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Final,
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -59,8 +61,16 @@ from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.lite_agent_output import LiteAgentOutput
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.mcp import MCPServerConfig
|
||||
from crewai.mcp.tool_resolver import MCPToolResolver
|
||||
from crewai.mcp import (
|
||||
MCPClient,
|
||||
MCPServerConfig,
|
||||
MCPServerHTTP,
|
||||
MCPServerSSE,
|
||||
MCPServerStdio,
|
||||
)
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
@@ -101,8 +111,18 @@ if TYPE_CHECKING:
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
# MCP Connection timeout constants (in seconds)
|
||||
MCP_CONNECTION_TIMEOUT: Final[int] = 10
|
||||
MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30
|
||||
MCP_DISCOVERY_TIMEOUT: Final[int] = 15
|
||||
MCP_MAX_RETRIES: Final[int] = 3
|
||||
|
||||
_passthrough_exceptions: tuple[type[Exception], ...] = ()
|
||||
|
||||
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
|
||||
_mcp_schema_cache: dict[str, Any] = {}
|
||||
_cache_ttl: Final[int] = 300 # 5 minutes
|
||||
|
||||
|
||||
class Agent(BaseAgent):
|
||||
"""Represents an agent in a system.
|
||||
@@ -134,7 +154,7 @@ class Agent(BaseAgent):
|
||||
model_config = ConfigDict()
|
||||
|
||||
_times_executed: int = PrivateAttr(default=0)
|
||||
_mcp_resolver: MCPToolResolver | None = PrivateAttr(default=None)
|
||||
_mcp_clients: list[Any] = PrivateAttr(default_factory=list)
|
||||
_last_messages: list[LLMMessage] = PrivateAttr(default_factory=list)
|
||||
max_execution_time: int | None = Field(
|
||||
default=None,
|
||||
@@ -364,9 +384,9 @@ class Agent(BaseAgent):
|
||||
)
|
||||
if unified_memory is not None:
|
||||
query = task.description
|
||||
matches = unified_memory.recall(query, limit=5)
|
||||
matches = unified_memory.recall(query, limit=10)
|
||||
if matches:
|
||||
memory = "Relevant memories:\n" + "\n".join(
|
||||
memory = "Relevant memories:\n" + "\n\n".join(
|
||||
m.format() for m in matches
|
||||
)
|
||||
if memory.strip() != "":
|
||||
@@ -914,17 +934,544 @@ class Agent(BaseAgent):
|
||||
def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
|
||||
"""Convert MCP server references/configs to CrewAI tools.
|
||||
|
||||
Delegates to :class:`~crewai.mcp.tool_resolver.MCPToolResolver`.
|
||||
Supports both string references (backwards compatible) and structured
|
||||
configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE).
|
||||
|
||||
Args:
|
||||
mcps: List of MCP server references (strings) or configurations.
|
||||
|
||||
Returns:
|
||||
List of BaseTool instances from MCP servers.
|
||||
"""
|
||||
self._cleanup_mcp_clients()
|
||||
self._mcp_resolver = MCPToolResolver(agent=self, logger=self._logger)
|
||||
return self._mcp_resolver.resolve(mcps)
|
||||
all_tools = []
|
||||
clients = []
|
||||
|
||||
for mcp_config in mcps:
|
||||
if isinstance(mcp_config, str):
|
||||
tools = self._get_mcp_tools_from_string(mcp_config)
|
||||
else:
|
||||
tools, client = self._get_native_mcp_tools(mcp_config)
|
||||
if client:
|
||||
clients.append(client)
|
||||
|
||||
all_tools.extend(tools)
|
||||
|
||||
# Store clients for cleanup
|
||||
self._mcp_clients.extend(clients)
|
||||
return all_tools
|
||||
|
||||
def _cleanup_mcp_clients(self) -> None:
|
||||
"""Cleanup MCP client connections after task execution."""
|
||||
if self._mcp_resolver is not None:
|
||||
self._mcp_resolver.cleanup()
|
||||
self._mcp_resolver = None
|
||||
if not self._mcp_clients:
|
||||
return
|
||||
|
||||
async def _disconnect_all() -> None:
|
||||
for client in self._mcp_clients:
|
||||
if client and hasattr(client, "connected") and client.connected:
|
||||
await client.disconnect()
|
||||
|
||||
try:
|
||||
asyncio.run(_disconnect_all())
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during MCP client cleanup: {e}")
|
||||
finally:
|
||||
self._mcp_clients.clear()
|
||||
|
||||
def _get_mcp_tools_from_string(self, mcp_ref: str) -> list[BaseTool]:
|
||||
"""Get tools from legacy string-based MCP references.
|
||||
|
||||
This method maintains backwards compatibility with string-based
|
||||
MCP references (https://... and crewai-amp:...).
|
||||
|
||||
Args:
|
||||
mcp_ref: String reference to MCP server.
|
||||
|
||||
Returns:
|
||||
List of BaseTool instances.
|
||||
"""
|
||||
if mcp_ref.startswith("crewai-amp:"):
|
||||
return self._get_amp_mcp_tools(mcp_ref)
|
||||
if mcp_ref.startswith("https://"):
|
||||
return self._get_external_mcp_tools(mcp_ref)
|
||||
return []
|
||||
|
||||
def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]:
|
||||
"""Get tools from external HTTPS MCP server with graceful error handling."""
|
||||
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||
|
||||
# Parse server URL and optional tool name
|
||||
if "#" in mcp_ref:
|
||||
server_url, specific_tool = mcp_ref.split("#", 1)
|
||||
else:
|
||||
server_url, specific_tool = mcp_ref, None
|
||||
|
||||
server_params = {"url": server_url}
|
||||
server_name = self._extract_server_name(server_url)
|
||||
|
||||
try:
|
||||
# Get tool schemas with timeout and error handling
|
||||
tool_schemas = self._get_mcp_tool_schemas(server_params)
|
||||
|
||||
if not tool_schemas:
|
||||
self._logger.log(
|
||||
"warning", f"No tools discovered from MCP server: {server_url}"
|
||||
)
|
||||
return []
|
||||
|
||||
tools = []
|
||||
for tool_name, schema in tool_schemas.items():
|
||||
# Skip if specific tool requested and this isn't it
|
||||
if specific_tool and tool_name != specific_tool:
|
||||
continue
|
||||
|
||||
try:
|
||||
wrapper = MCPToolWrapper(
|
||||
mcp_server_params=server_params,
|
||||
tool_name=tool_name,
|
||||
tool_schema=schema,
|
||||
server_name=server_name,
|
||||
)
|
||||
tools.append(wrapper)
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Failed to create MCP tool wrapper for {tool_name}: {e}",
|
||||
)
|
||||
continue
|
||||
|
||||
if specific_tool and not tools:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Specific tool '{specific_tool}' not found on MCP server: {server_url}",
|
||||
)
|
||||
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
"warning", f"Failed to connect to MCP server {server_url}: {e}"
|
||||
)
|
||||
return []
|
||||
|
||||
def _get_native_mcp_tools(
|
||||
self, mcp_config: MCPServerConfig
|
||||
) -> tuple[list[BaseTool], Any | None]:
|
||||
"""Get tools from MCP server using structured configuration.
|
||||
|
||||
This method creates an MCP client based on the configuration type,
|
||||
connects to the server, discovers tools, applies filtering, and
|
||||
returns wrapped tools along with the client instance for cleanup.
|
||||
|
||||
Args:
|
||||
mcp_config: MCP server configuration (MCPServerStdio, MCPServerHTTP, or MCPServerSSE).
|
||||
|
||||
Returns:
|
||||
Tuple of (list of BaseTool instances, MCPClient instance for cleanup).
|
||||
"""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
|
||||
transport: StdioTransport | HTTPTransport | SSETransport
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
args=mcp_config.args,
|
||||
env=mcp_config.env,
|
||||
)
|
||||
server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}"
|
||||
elif isinstance(mcp_config, MCPServerHTTP):
|
||||
transport = HTTPTransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
streamable=mcp_config.streamable,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
elif isinstance(mcp_config, MCPServerSSE):
|
||||
transport = SSETransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")
|
||||
|
||||
client = MCPClient(
|
||||
transport=transport,
|
||||
cache_tools_list=mcp_config.cache_tools_list,
|
||||
)
|
||||
|
||||
async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
|
||||
"""Async helper to connect and list tools in same event loop."""
|
||||
|
||||
try:
|
||||
if not client.connected:
|
||||
await client.connect()
|
||||
|
||||
tools_list = await client.list_tools()
|
||||
|
||||
try:
|
||||
await client.disconnect()
|
||||
# Small delay to allow background tasks to finish cleanup
|
||||
# This helps prevent "cancel scope in different task" errors
|
||||
# when asyncio.run() closes the event loop
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during disconnect: {e}")
|
||||
|
||||
return tools_list
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError(
|
||||
f"Error during setup client and list tools: {e}"
|
||||
) from e
|
||||
|
||||
try:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run, _setup_client_and_list_tools()
|
||||
)
|
||||
tools_list = future.result()
|
||||
except RuntimeError:
|
||||
try:
|
||||
tools_list = asyncio.run(_setup_client_and_list_tools())
|
||||
except RuntimeError as e:
|
||||
error_msg = str(e).lower()
|
||||
if "cancel scope" in error_msg or "task" in error_msg:
|
||||
raise ConnectionError(
|
||||
"MCP connection failed due to event loop cleanup issues. "
|
||||
"This may be due to authentication errors or server unavailability."
|
||||
) from e
|
||||
except asyncio.CancelledError as e:
|
||||
raise ConnectionError(
|
||||
"MCP connection was cancelled. This may indicate an authentication "
|
||||
"error or server unavailability."
|
||||
) from e
|
||||
|
||||
if mcp_config.tool_filter:
|
||||
filtered_tools = []
|
||||
for tool in tools_list:
|
||||
if callable(mcp_config.tool_filter):
|
||||
try:
|
||||
from crewai.mcp.filters import ToolFilterContext
|
||||
|
||||
context = ToolFilterContext(
|
||||
agent=self,
|
||||
server_name=server_name,
|
||||
run_context=None,
|
||||
)
|
||||
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
except (TypeError, AttributeError):
|
||||
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
else:
|
||||
# Not callable - include tool
|
||||
filtered_tools.append(tool)
|
||||
tools_list = filtered_tools
|
||||
|
||||
tools = []
|
||||
for tool_def in tools_list:
|
||||
tool_name = tool_def.get("name", "")
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
# Convert inputSchema to Pydantic model if present
|
||||
args_schema = None
|
||||
if tool_def.get("inputSchema"):
|
||||
args_schema = self._json_schema_to_pydantic(
|
||||
tool_name, tool_def["inputSchema"]
|
||||
)
|
||||
|
||||
tool_schema = {
|
||||
"description": tool_def.get("description", ""),
|
||||
"args_schema": args_schema,
|
||||
}
|
||||
|
||||
try:
|
||||
native_tool = MCPNativeTool(
|
||||
mcp_client=client,
|
||||
tool_name=tool_name,
|
||||
tool_schema=tool_schema,
|
||||
server_name=server_name,
|
||||
)
|
||||
tools.append(native_tool)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Failed to create native MCP tool: {e}")
|
||||
continue
|
||||
|
||||
return cast(list[BaseTool], tools), client
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
asyncio.run(client.disconnect())
|
||||
|
||||
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||
|
||||
def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
|
||||
"""Get tools from CrewAI AMP MCP marketplace."""
|
||||
# Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name"
|
||||
amp_part = amp_ref.replace("crewai-amp:", "")
|
||||
if "#" in amp_part:
|
||||
mcp_name, specific_tool = amp_part.split("#", 1)
|
||||
else:
|
||||
mcp_name, specific_tool = amp_part, None
|
||||
|
||||
# Call AMP API to get MCP server URLs
|
||||
mcp_servers = self._fetch_amp_mcp_servers(mcp_name)
|
||||
|
||||
tools = []
|
||||
for server_config in mcp_servers:
|
||||
server_ref = server_config["url"]
|
||||
if specific_tool:
|
||||
server_ref += f"#{specific_tool}"
|
||||
server_tools = self._get_external_mcp_tools(server_ref)
|
||||
tools.extend(server_tools)
|
||||
|
||||
return tools
|
||||
|
||||
@staticmethod
|
||||
def _extract_server_name(server_url: str) -> str:
|
||||
"""Extract clean server name from URL for tool prefixing."""
|
||||
|
||||
parsed = urlparse(server_url)
|
||||
domain = parsed.netloc.replace(".", "_")
|
||||
path = parsed.path.replace("/", "_").strip("_")
|
||||
return f"{domain}_{path}" if path else domain
|
||||
|
||||
def _get_mcp_tool_schemas(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Get tool schemas from MCP server for wrapper creation with caching."""
|
||||
server_url = server_params["url"]
|
||||
|
||||
# Check cache first
|
||||
cache_key = server_url
|
||||
current_time = time.time()
|
||||
|
||||
if cache_key in _mcp_schema_cache:
|
||||
cached_data, cache_time = _mcp_schema_cache[cache_key]
|
||||
if current_time - cache_time < _cache_ttl:
|
||||
self._logger.log(
|
||||
"debug", f"Using cached MCP tool schemas for {server_url}"
|
||||
)
|
||||
return cached_data # type: ignore[no-any-return]
|
||||
|
||||
try:
|
||||
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
|
||||
|
||||
# Cache successful results
|
||||
_mcp_schema_cache[cache_key] = (schemas, current_time)
|
||||
|
||||
return schemas
|
||||
except Exception as e:
|
||||
# Log warning but don't raise - this allows graceful degradation
|
||||
self._logger.log(
|
||||
"warning", f"Failed to get MCP tool schemas from {server_url}: {e}"
|
||||
)
|
||||
return {}
|
||||
|
||||
async def _get_mcp_tool_schemas_async(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
|
||||
server_url = server_params["url"]
|
||||
return await self._retry_mcp_discovery(
|
||||
self._discover_mcp_tools_with_timeout, server_url
|
||||
)
|
||||
|
||||
async def _retry_mcp_discovery(
|
||||
self, operation_func: Any, server_url: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(MCP_MAX_RETRIES):
|
||||
# Execute single attempt outside try-except loop structure
|
||||
result, error, should_retry = await self._attempt_mcp_discovery(
|
||||
operation_func, server_url
|
||||
)
|
||||
|
||||
# Success case - return immediately
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Non-retryable error - raise immediately
|
||||
if not should_retry:
|
||||
raise RuntimeError(error)
|
||||
|
||||
# Retryable error - continue with backoff
|
||||
last_error = error
|
||||
if attempt < MCP_MAX_RETRIES - 1:
|
||||
wait_time = 2**attempt # Exponential backoff
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Failed to discover MCP tools after {MCP_MAX_RETRIES} attempts: {last_error}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _attempt_mcp_discovery(
|
||||
operation_func: Any, server_url: str
|
||||
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
|
||||
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
|
||||
try:
|
||||
result = await operation_func(server_url)
|
||||
return result, "", False
|
||||
|
||||
except ImportError:
|
||||
return (
|
||||
None,
|
||||
"MCP library not available. Please install with: pip install mcp",
|
||||
False,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return (
|
||||
None,
|
||||
f"MCP discovery timed out after {MCP_DISCOVERY_TIMEOUT} seconds",
|
||||
True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
|
||||
# Classify errors as retryable or non-retryable
|
||||
if "authentication" in error_str or "unauthorized" in error_str:
|
||||
return None, f"Authentication failed for MCP server: {e!s}", False
|
||||
if "connection" in error_str or "network" in error_str:
|
||||
return None, f"Network connection failed: {e!s}", True
|
||||
if "json" in error_str or "parsing" in error_str:
|
||||
return None, f"Server response parsing error: {e!s}", True
|
||||
return None, f"MCP discovery error: {e!s}", False
|
||||
|
||||
async def _discover_mcp_tools_with_timeout(
|
||||
self, server_url: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Discover MCP tools with timeout wrapper."""
|
||||
return await asyncio.wait_for(
|
||||
self._discover_mcp_tools(server_url), timeout=MCP_DISCOVERY_TIMEOUT
|
||||
)
|
||||
|
||||
async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict[str, Any]]:
|
||||
"""Discover tools from MCP server with proper timeout handling."""
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
async with streamablehttp_client(server_url) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection with timeout
|
||||
await asyncio.wait_for(
|
||||
session.initialize(), timeout=MCP_CONNECTION_TIMEOUT
|
||||
)
|
||||
|
||||
# List available tools with timeout
|
||||
tools_result = await asyncio.wait_for(
|
||||
session.list_tools(),
|
||||
timeout=MCP_DISCOVERY_TIMEOUT - MCP_CONNECTION_TIMEOUT,
|
||||
)
|
||||
|
||||
schemas = {}
|
||||
for tool in tools_result.tools:
|
||||
args_schema = None
|
||||
if hasattr(tool, "inputSchema") and tool.inputSchema:
|
||||
args_schema = self._json_schema_to_pydantic(
|
||||
sanitize_tool_name(tool.name), tool.inputSchema
|
||||
)
|
||||
|
||||
schemas[sanitize_tool_name(tool.name)] = {
|
||||
"description": getattr(tool, "description", ""),
|
||||
"args_schema": args_schema,
|
||||
}
|
||||
return schemas
|
||||
|
||||
def _json_schema_to_pydantic(
|
||||
self, tool_name: str, json_schema: dict[str, Any]
|
||||
) -> type:
|
||||
"""Convert JSON Schema to Pydantic model for tool arguments.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool (used for model naming)
|
||||
json_schema: JSON Schema dict with 'properties', 'required', etc.
|
||||
|
||||
Returns:
|
||||
Pydantic BaseModel class
|
||||
"""
|
||||
from pydantic import Field, create_model
|
||||
|
||||
properties = json_schema.get("properties", {})
|
||||
required_fields = json_schema.get("required", [])
|
||||
|
||||
field_definitions: dict[str, Any] = {}
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
field_type = self._json_type_to_python(field_schema)
|
||||
field_description = field_schema.get("description", "")
|
||||
|
||||
is_required = field_name in required_fields
|
||||
|
||||
if is_required:
|
||||
field_definitions[field_name] = (
|
||||
field_type,
|
||||
Field(..., description=field_description),
|
||||
)
|
||||
else:
|
||||
field_definitions[field_name] = (
|
||||
field_type | None,
|
||||
Field(default=None, description=field_description),
|
||||
)
|
||||
|
||||
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
|
||||
return create_model(model_name, **field_definitions) # type: ignore[no-any-return]
|
||||
|
||||
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
|
||||
"""Convert JSON Schema type to Python type.
|
||||
|
||||
Args:
|
||||
field_schema: JSON Schema field definition
|
||||
|
||||
Returns:
|
||||
Python type
|
||||
"""
|
||||
|
||||
json_type = field_schema.get("type")
|
||||
|
||||
if "anyOf" in field_schema:
|
||||
types: list[type] = []
|
||||
for option in field_schema["anyOf"]:
|
||||
if "const" in option:
|
||||
types.append(str)
|
||||
else:
|
||||
types.append(self._json_type_to_python(option))
|
||||
unique_types = list(set(types))
|
||||
if len(unique_types) > 1:
|
||||
result: Any = unique_types[0]
|
||||
for t in unique_types[1:]:
|
||||
result = result | t
|
||||
return result # type: ignore[no-any-return]
|
||||
return unique_types[0]
|
||||
|
||||
type_mapping: dict[str | None, type] = {
|
||||
"string": str,
|
||||
"number": float,
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
}
|
||||
|
||||
return type_mapping.get(json_type, Any)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
|
||||
"""Fetch MCP server configurations from CrewAI AMP API."""
|
||||
# TODO: Implement AMP API call to "integrations/mcps" endpoint
|
||||
# Should return list of server configs with URLs
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_multimodal_tools() -> Sequence[BaseTool]:
|
||||
@@ -1264,10 +1811,10 @@ class Agent(BaseAgent):
|
||||
),
|
||||
)
|
||||
start_time = time.time()
|
||||
matches = agent_memory.recall(formatted_messages, limit=5)
|
||||
matches = agent_memory.recall(formatted_messages, limit=10)
|
||||
memory_block = ""
|
||||
if matches:
|
||||
memory_block = "Relevant memories:\n" + "\n".join(
|
||||
memory_block = "Relevant memories:\n" + "\n\n".join(
|
||||
m.format() for m in matches
|
||||
)
|
||||
if memory_block:
|
||||
|
||||
@@ -4,8 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
import re
|
||||
from typing import Any, Final, Literal
|
||||
from typing import Any, Literal
|
||||
import uuid
|
||||
|
||||
from pydantic import (
|
||||
@@ -37,11 +36,6 @@ from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
|
||||
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
|
||||
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$"
|
||||
)
|
||||
|
||||
|
||||
PlatformApp = Literal[
|
||||
"asana",
|
||||
"box",
|
||||
@@ -203,7 +197,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
)
|
||||
mcps: list[str | MCPServerConfig] | None = Field(
|
||||
default=None,
|
||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.",
|
||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.",
|
||||
)
|
||||
memory: Any = Field(
|
||||
default=None,
|
||||
@@ -282,16 +276,14 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
validated_mcps: list[str | MCPServerConfig] = []
|
||||
for mcp in mcps:
|
||||
if isinstance(mcp, str):
|
||||
if mcp.startswith("https://"):
|
||||
validated_mcps.append(mcp)
|
||||
elif _SLUG_RE.match(mcp):
|
||||
if mcp.startswith(("https://", "crewai-amp:")):
|
||||
validated_mcps.append(mcp)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid MCP reference: {mcp!r}. "
|
||||
"String references must be an 'https://' URL or a valid "
|
||||
"slug (e.g. 'notion', 'notion#search', 'crewai-amp:notion')."
|
||||
f"Invalid MCP reference: {mcp}. "
|
||||
"String references must start with 'https://' or 'crewai-amp:'"
|
||||
)
|
||||
|
||||
elif isinstance(mcp, (MCPServerConfig)):
|
||||
validated_mcps.append(mcp)
|
||||
else:
|
||||
|
||||
@@ -190,15 +190,6 @@ class PlusAPI:
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def get_mcp_configs(self, slugs: list[str]) -> httpx.Response:
|
||||
"""Get MCP server configurations for the given slugs."""
|
||||
return self._make_request(
|
||||
"GET",
|
||||
f"{self.INTEGRATIONS_RESOURCE}/mcp_configs",
|
||||
params={"slugs": ",".join(slugs)},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def get_triggers(self) -> httpx.Response:
|
||||
"""Get all available triggers from integrations."""
|
||||
return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps")
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.0"
|
||||
"crewai[tools]==1.9.3"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.0"
|
||||
"crewai[tools]==1.9.3"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.0"
|
||||
"crewai[tools]>=0.203.1"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -63,7 +63,6 @@ from crewai.events.types.logging_events import (
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.mcp_events import (
|
||||
MCPConfigFetchFailedEvent,
|
||||
MCPConnectionCompletedEvent,
|
||||
MCPConnectionFailedEvent,
|
||||
MCPConnectionStartedEvent,
|
||||
@@ -166,7 +165,6 @@ __all__ = [
|
||||
"LiteAgentExecutionCompletedEvent",
|
||||
"LiteAgentExecutionErrorEvent",
|
||||
"LiteAgentExecutionStartedEvent",
|
||||
"MCPConfigFetchFailedEvent",
|
||||
"MCPConnectionCompletedEvent",
|
||||
"MCPConnectionFailedEvent",
|
||||
"MCPConnectionStartedEvent",
|
||||
|
||||
@@ -68,7 +68,6 @@ from crewai.events.types.logging_events import (
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.mcp_events import (
|
||||
MCPConfigFetchFailedEvent,
|
||||
MCPConnectionCompletedEvent,
|
||||
MCPConnectionFailedEvent,
|
||||
MCPConnectionStartedEvent,
|
||||
@@ -666,16 +665,6 @@ class EventListener(BaseEventListener):
|
||||
event.error_type,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPConfigFetchFailedEvent)
|
||||
def on_mcp_config_fetch_failed(
|
||||
_: Any, event: MCPConfigFetchFailedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_mcp_config_fetch_failed(
|
||||
event.slug,
|
||||
event.error,
|
||||
event.error_type,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPToolExecutionStartedEvent)
|
||||
def on_mcp_tool_execution_started(
|
||||
_: Any, event: MCPToolExecutionStartedEvent
|
||||
|
||||
@@ -67,7 +67,6 @@ from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.mcp_events import (
|
||||
MCPConfigFetchFailedEvent,
|
||||
MCPConnectionCompletedEvent,
|
||||
MCPConnectionFailedEvent,
|
||||
MCPConnectionStartedEvent,
|
||||
@@ -182,5 +181,4 @@ EventTypes = (
|
||||
| MCPToolExecutionStartedEvent
|
||||
| MCPToolExecutionCompletedEvent
|
||||
| MCPToolExecutionFailedEvent
|
||||
| MCPConfigFetchFailedEvent
|
||||
)
|
||||
|
||||
@@ -83,16 +83,3 @@ class MCPToolExecutionFailedEvent(MCPEvent):
|
||||
error_type: str | None = None # "timeout", "validation", "server_error", etc.
|
||||
started_at: datetime | None = None
|
||||
failed_at: datetime | None = None
|
||||
|
||||
|
||||
class MCPConfigFetchFailedEvent(BaseEvent):
|
||||
"""Event emitted when fetching an AMP MCP server config fails.
|
||||
|
||||
This covers cases where the slug is not connected, the API call
|
||||
failed, or native MCP resolution failed after config was fetched.
|
||||
"""
|
||||
|
||||
type: str = "mcp_config_fetch_failed"
|
||||
slug: str
|
||||
error: str
|
||||
error_type: str | None = None # "not_connected", "api_error", "connection_failed"
|
||||
|
||||
@@ -1512,34 +1512,6 @@ To enable tracing, do any one of these:
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
def handle_mcp_config_fetch_failed(
|
||||
self,
|
||||
slug: str,
|
||||
error: str = "",
|
||||
error_type: str | None = None,
|
||||
) -> None:
|
||||
"""Handle MCP config fetch failed event (AMP resolution failures)."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = Text()
|
||||
content.append("MCP Config Fetch Failed\n\n", style="red bold")
|
||||
content.append("Server: ", style="white")
|
||||
content.append(f"{slug}\n", style="red")
|
||||
|
||||
if error_type:
|
||||
content.append("Error Type: ", style="white")
|
||||
content.append(f"{error_type}\n", style="red")
|
||||
|
||||
if error:
|
||||
content.append("\nError: ", style="white bold")
|
||||
error_preview = error[:500] + "..." if len(error) > 500 else error
|
||||
content.append(f"{error_preview}\n", style="red")
|
||||
|
||||
panel = self.create_panel(content, "❌ MCP Config Failed", "red")
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
def handle_mcp_tool_execution_started(
|
||||
self,
|
||||
server_name: str,
|
||||
|
||||
@@ -427,7 +427,7 @@ class LLM(BaseLLM):
|
||||
f"installed.\n\n"
|
||||
f"To fix this, either:\n"
|
||||
f" 1. Install LiteLLM for broad model support: "
|
||||
f"uv add 'crewai[litellm]'\n"
|
||||
f"uv add litellm\n"
|
||||
f"or\n"
|
||||
f"pip install litellm\n\n"
|
||||
f"For more details, see: "
|
||||
|
||||
@@ -18,7 +18,6 @@ from crewai.mcp.filters import (
|
||||
create_dynamic_tool_filter,
|
||||
create_static_tool_filter,
|
||||
)
|
||||
from crewai.mcp.tool_resolver import MCPToolResolver
|
||||
from crewai.mcp.transports.base import BaseTransport, TransportType
|
||||
|
||||
|
||||
@@ -29,7 +28,6 @@ __all__ = [
|
||||
"MCPServerHTTP",
|
||||
"MCPServerSSE",
|
||||
"MCPServerStdio",
|
||||
"MCPToolResolver",
|
||||
"StaticToolFilter",
|
||||
"ToolFilter",
|
||||
"ToolFilterContext",
|
||||
|
||||
@@ -6,7 +6,7 @@ from contextlib import AsyncExitStack
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, NamedTuple
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -34,13 +34,6 @@ from crewai.mcp.transports.stdio import StdioTransport
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
|
||||
class _MCPToolResult(NamedTuple):
|
||||
"""Internal result from an MCP tool call, carrying the ``isError`` flag."""
|
||||
|
||||
content: str
|
||||
is_error: bool
|
||||
|
||||
|
||||
# MCP Connection timeout constants (in seconds)
|
||||
MCP_CONNECTION_TIMEOUT = 30 # Increased for slow servers
|
||||
MCP_TOOL_EXECUTION_TIMEOUT = 30
|
||||
@@ -427,7 +420,6 @@ class MCPClient:
|
||||
return [
|
||||
{
|
||||
"name": sanitize_tool_name(tool.name),
|
||||
"original_name": tool.name,
|
||||
"description": getattr(tool, "description", ""),
|
||||
"inputSchema": getattr(tool, "inputSchema", {}),
|
||||
}
|
||||
@@ -469,46 +461,29 @@ class MCPClient:
|
||||
)
|
||||
|
||||
try:
|
||||
tool_result: _MCPToolResult = await self._retry_operation(
|
||||
result = await self._retry_operation(
|
||||
lambda: self._call_tool_impl(tool_name, cleaned_arguments),
|
||||
timeout=self.execution_timeout,
|
||||
)
|
||||
|
||||
finished_at = datetime.now()
|
||||
execution_duration_ms = (finished_at - started_at).total_seconds() * 1000
|
||||
completed_at = datetime.now()
|
||||
execution_duration_ms = (completed_at - started_at).total_seconds() * 1000
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MCPToolExecutionCompletedEvent(
|
||||
server_name=server_name,
|
||||
server_url=server_url,
|
||||
transport_type=transport_type,
|
||||
tool_name=tool_name,
|
||||
tool_args=cleaned_arguments,
|
||||
result=result,
|
||||
started_at=started_at,
|
||||
completed_at=completed_at,
|
||||
execution_duration_ms=execution_duration_ms,
|
||||
),
|
||||
)
|
||||
|
||||
if tool_result.is_error:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MCPToolExecutionFailedEvent(
|
||||
server_name=server_name,
|
||||
server_url=server_url,
|
||||
transport_type=transport_type,
|
||||
tool_name=tool_name,
|
||||
tool_args=cleaned_arguments,
|
||||
error=tool_result.content,
|
||||
error_type="tool_error",
|
||||
started_at=started_at,
|
||||
failed_at=finished_at,
|
||||
),
|
||||
)
|
||||
else:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MCPToolExecutionCompletedEvent(
|
||||
server_name=server_name,
|
||||
server_url=server_url,
|
||||
transport_type=transport_type,
|
||||
tool_name=tool_name,
|
||||
tool_args=cleaned_arguments,
|
||||
result=tool_result.content,
|
||||
started_at=started_at,
|
||||
completed_at=finished_at,
|
||||
execution_duration_ms=execution_duration_ms,
|
||||
),
|
||||
)
|
||||
|
||||
return tool_result.content
|
||||
return result
|
||||
except Exception as e:
|
||||
failed_at = datetime.now()
|
||||
error_type = (
|
||||
@@ -589,27 +564,23 @@ class MCPClient:
|
||||
|
||||
return cleaned
|
||||
|
||||
async def _call_tool_impl(
|
||||
self, tool_name: str, arguments: dict[str, Any]
|
||||
) -> _MCPToolResult:
|
||||
async def _call_tool_impl(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""Internal implementation of call_tool."""
|
||||
result = await asyncio.wait_for(
|
||||
self.session.call_tool(tool_name, arguments),
|
||||
timeout=self.execution_timeout,
|
||||
)
|
||||
|
||||
is_error = getattr(result, "isError", False) or False
|
||||
|
||||
# Extract result content
|
||||
if hasattr(result, "content") and result.content:
|
||||
if isinstance(result.content, list) and len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
if hasattr(content_item, "text"):
|
||||
return _MCPToolResult(str(content_item.text), is_error)
|
||||
return _MCPToolResult(str(content_item), is_error)
|
||||
return _MCPToolResult(str(result.content), is_error)
|
||||
return str(content_item.text)
|
||||
return str(content_item)
|
||||
return str(result.content)
|
||||
|
||||
return _MCPToolResult(str(result), is_error)
|
||||
return str(result)
|
||||
|
||||
async def list_prompts(self) -> list[dict[str, Any]]:
|
||||
"""List available prompts from MCP server.
|
||||
|
||||
@@ -1,592 +0,0 @@
|
||||
"""MCP tool resolution for CrewAI agents.
|
||||
|
||||
This module extracts all MCP-related tool resolution logic from the Agent class
|
||||
into a standalone MCPToolResolver. It handles three flavours of MCP reference:
|
||||
|
||||
1. Native configs: MCPServerStdio / MCPServerHTTP / MCPServerSSE objects.
|
||||
2. HTTPS URLs: e.g. "https://mcp.example.com/api"
|
||||
3. AMP references: e.g. "notion" or "notion#search" (legacy "crewai-amp:" prefix also works)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Final, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai.mcp.client import MCPClient
|
||||
from crewai.mcp.config import (
|
||||
MCPServerConfig,
|
||||
MCPServerHTTP,
|
||||
MCPServerSSE,
|
||||
MCPServerStdio,
|
||||
)
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
MCP_CONNECTION_TIMEOUT: Final[int] = 10
|
||||
MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30
|
||||
MCP_DISCOVERY_TIMEOUT: Final[int] = 15
|
||||
MCP_MAX_RETRIES: Final[int] = 3
|
||||
|
||||
_mcp_schema_cache: dict[str, Any] = {}
|
||||
_cache_ttl: Final[int] = 300 # 5 minutes
|
||||
|
||||
|
||||
class MCPToolResolver:
|
||||
"""Resolves MCP server references / configs into CrewAI ``BaseTool`` instances.
|
||||
|
||||
Typical lifecycle::
|
||||
|
||||
resolver = MCPToolResolver(agent=my_agent, logger=my_agent._logger)
|
||||
tools = resolver.resolve(my_agent.mcps)
|
||||
# … agent executes tasks using *tools* …
|
||||
resolver.cleanup()
|
||||
|
||||
The resolver owns the MCP client connections it creates and is responsible
|
||||
for tearing them down via :meth:`cleanup`.
|
||||
"""
|
||||
|
||||
def __init__(self, agent: Any, logger: Logger) -> None:
|
||||
self._agent = agent
|
||||
self._logger = logger
|
||||
self._clients: list[Any] = []
|
||||
|
||||
@property
|
||||
def clients(self) -> list[Any]:
|
||||
return list(self._clients)
|
||||
|
||||
def resolve(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
|
||||
"""Convert MCP server references/configs to CrewAI tools."""
|
||||
all_tools: list[BaseTool] = []
|
||||
amp_refs: list[tuple[str, str | None]] = []
|
||||
|
||||
for mcp_config in mcps:
|
||||
if isinstance(mcp_config, str) and mcp_config.startswith("https://"):
|
||||
all_tools.extend(self._resolve_external(mcp_config))
|
||||
elif isinstance(mcp_config, str):
|
||||
amp_refs.append(self._parse_amp_ref(mcp_config))
|
||||
else:
|
||||
tools, client = self._resolve_native(mcp_config)
|
||||
all_tools.extend(tools)
|
||||
if client:
|
||||
self._clients.append(client)
|
||||
|
||||
if amp_refs:
|
||||
tools, clients = self._resolve_amp(amp_refs)
|
||||
all_tools.extend(tools)
|
||||
self._clients.extend(clients)
|
||||
|
||||
return all_tools
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Disconnect all MCP client connections."""
|
||||
if not self._clients:
|
||||
return
|
||||
|
||||
async def _disconnect_all() -> None:
|
||||
for client in self._clients:
|
||||
if client and hasattr(client, "connected") and client.connected:
|
||||
await client.disconnect()
|
||||
|
||||
try:
|
||||
asyncio.run(_disconnect_all())
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during MCP client cleanup: {e}")
|
||||
finally:
|
||||
self._clients.clear()
|
||||
|
||||
@staticmethod
|
||||
def _parse_amp_ref(mcp_config: str) -> tuple[str, str | None]:
|
||||
"""Parse an AMP reference into *(slug, optional tool name)*.
|
||||
|
||||
Accepts both bare slugs (``"notion"``, ``"notion#search"``) and the
|
||||
legacy ``"crewai-amp:notion"`` form.
|
||||
"""
|
||||
bare = mcp_config.removeprefix("crewai-amp:")
|
||||
slug, _, specific_tool = bare.partition("#")
|
||||
return slug, specific_tool or None
|
||||
|
||||
def _resolve_amp(
|
||||
self, amp_refs: list[tuple[str, str | None]]
|
||||
) -> tuple[list[BaseTool], list[Any]]:
|
||||
"""Fetch AMP configs in bulk and return their tools and clients.
|
||||
|
||||
Resolves each unique slug only once (single connection per server),
|
||||
then applies per-ref tool filters to select specific tools.
|
||||
"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.mcp_events import MCPConfigFetchFailedEvent
|
||||
|
||||
unique_slugs = list(dict.fromkeys(slug for slug, _ in amp_refs))
|
||||
amp_configs_map = self._fetch_amp_mcp_configs(unique_slugs)
|
||||
|
||||
all_tools: list[BaseTool] = []
|
||||
all_clients: list[Any] = []
|
||||
|
||||
resolved_cache: dict[str, tuple[list[BaseTool], Any | None]] = {}
|
||||
|
||||
for slug in unique_slugs:
|
||||
config_dict = amp_configs_map.get(slug)
|
||||
if not config_dict:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MCPConfigFetchFailedEvent(
|
||||
slug=slug,
|
||||
error=f"Config for '{slug}' not found. Make sure it is connected in your account.",
|
||||
error_type="not_connected",
|
||||
),
|
||||
)
|
||||
continue
|
||||
|
||||
mcp_server_config = self._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
try:
|
||||
tools, client = self._resolve_native(mcp_server_config)
|
||||
resolved_cache[slug] = (tools, client)
|
||||
if client:
|
||||
all_clients.append(client)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MCPConfigFetchFailedEvent(
|
||||
slug=slug,
|
||||
error=str(e),
|
||||
error_type="connection_failed",
|
||||
),
|
||||
)
|
||||
|
||||
for slug, specific_tool in amp_refs:
|
||||
cached = resolved_cache.get(slug)
|
||||
if not cached:
|
||||
continue
|
||||
|
||||
slug_tools, _ = cached
|
||||
if specific_tool:
|
||||
all_tools.extend(
|
||||
t for t in slug_tools if t.name.endswith(f"_{specific_tool}")
|
||||
)
|
||||
else:
|
||||
all_tools.extend(slug_tools)
|
||||
|
||||
return all_tools, all_clients
|
||||
|
||||
def _fetch_amp_mcp_configs(self, slugs: list[str]) -> dict[str, dict[str, Any]]:
|
||||
"""Fetch MCP server configurations via CrewAI+ API.
|
||||
|
||||
Sends a GET request to the CrewAI+ mcps/configs endpoint with
|
||||
comma-separated slugs. CrewAI+ proxies the request to crewai-oauth.
|
||||
|
||||
API-level failures return ``{}``; individual slugs will then
|
||||
surface as ``MCPConfigFetchFailedEvent`` in :meth:`_resolve_amp`.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
try:
|
||||
from crewai_tools.tools.crewai_platform_tools.misc import (
|
||||
get_platform_integration_token,
|
||||
)
|
||||
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
|
||||
plus_api = PlusAPI(api_key=get_platform_integration_token())
|
||||
response = plus_api.get_mcp_configs(slugs)
|
||||
|
||||
if response.status_code == 200:
|
||||
configs: dict[str, dict[str, Any]] = response.json().get("configs", {})
|
||||
return configs
|
||||
|
||||
self._logger.log(
|
||||
"debug",
|
||||
f"Failed to fetch MCP configs: HTTP {response.status_code}",
|
||||
)
|
||||
return {}
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
self._logger.log("debug", f"Failed to fetch MCP configs: {e}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
self._logger.log("debug", f"Cannot fetch AMP MCP configs: {e}")
|
||||
return {}
|
||||
|
||||
def _resolve_external(self, mcp_ref: str) -> list[BaseTool]:
|
||||
"""Resolve an HTTPS MCP server URL into tools."""
|
||||
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||
|
||||
if "#" in mcp_ref:
|
||||
server_url, specific_tool = mcp_ref.split("#", 1)
|
||||
else:
|
||||
server_url, specific_tool = mcp_ref, None
|
||||
|
||||
server_params = {"url": server_url}
|
||||
server_name = self._extract_server_name(server_url)
|
||||
|
||||
try:
|
||||
tool_schemas = self._get_mcp_tool_schemas(server_params)
|
||||
|
||||
if not tool_schemas:
|
||||
self._logger.log(
|
||||
"warning", f"No tools discovered from MCP server: {server_url}"
|
||||
)
|
||||
return []
|
||||
|
||||
tools = []
|
||||
for tool_name, schema in tool_schemas.items():
|
||||
if specific_tool and tool_name != specific_tool:
|
||||
continue
|
||||
|
||||
try:
|
||||
wrapper = MCPToolWrapper(
|
||||
mcp_server_params=server_params,
|
||||
tool_name=tool_name,
|
||||
tool_schema=schema,
|
||||
server_name=server_name,
|
||||
)
|
||||
tools.append(wrapper)
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Failed to create MCP tool wrapper for {tool_name}: {e}",
|
||||
)
|
||||
continue
|
||||
|
||||
if specific_tool and not tools:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Specific tool '{specific_tool}' not found on MCP server: {server_url}",
|
||||
)
|
||||
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
"warning", f"Failed to connect to MCP server {server_url}: {e}"
|
||||
)
|
||||
return []
|
||||
|
||||
def _resolve_native(
|
||||
self, mcp_config: MCPServerConfig
|
||||
) -> tuple[list[BaseTool], Any | None]:
|
||||
"""Resolve an ``MCPServerConfig`` into tools, returning the client for cleanup."""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
|
||||
transport: StdioTransport | HTTPTransport | SSETransport
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
args=mcp_config.args,
|
||||
env=mcp_config.env,
|
||||
)
|
||||
server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}"
|
||||
elif isinstance(mcp_config, MCPServerHTTP):
|
||||
transport = HTTPTransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
streamable=mcp_config.streamable,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
elif isinstance(mcp_config, MCPServerSSE):
|
||||
transport = SSETransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")
|
||||
|
||||
client = MCPClient(
|
||||
transport=transport,
|
||||
cache_tools_list=mcp_config.cache_tools_list,
|
||||
)
|
||||
|
||||
async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
|
||||
try:
|
||||
if not client.connected:
|
||||
await client.connect()
|
||||
|
||||
tools_list = await client.list_tools()
|
||||
|
||||
try:
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during disconnect: {e}")
|
||||
|
||||
return tools_list
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError(
|
||||
f"Error during setup client and list tools: {e}"
|
||||
) from e
|
||||
|
||||
try:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run, _setup_client_and_list_tools()
|
||||
)
|
||||
tools_list = future.result()
|
||||
except RuntimeError:
|
||||
try:
|
||||
tools_list = asyncio.run(_setup_client_and_list_tools())
|
||||
except RuntimeError as e:
|
||||
error_msg = str(e).lower()
|
||||
if "cancel scope" in error_msg or "task" in error_msg:
|
||||
raise ConnectionError(
|
||||
"MCP connection failed due to event loop cleanup issues. "
|
||||
"This may be due to authentication errors or server unavailability."
|
||||
) from e
|
||||
except asyncio.CancelledError as e:
|
||||
raise ConnectionError(
|
||||
"MCP connection was cancelled. This may indicate an authentication "
|
||||
"error or server unavailability."
|
||||
) from e
|
||||
|
||||
if mcp_config.tool_filter:
|
||||
filtered_tools = []
|
||||
for tool in tools_list:
|
||||
if callable(mcp_config.tool_filter):
|
||||
try:
|
||||
from crewai.mcp.filters import ToolFilterContext
|
||||
|
||||
context = ToolFilterContext(
|
||||
agent=self._agent,
|
||||
server_name=server_name,
|
||||
run_context=None,
|
||||
)
|
||||
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
except (TypeError, AttributeError):
|
||||
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
else:
|
||||
filtered_tools.append(tool)
|
||||
tools_list = filtered_tools
|
||||
|
||||
tools = []
|
||||
for tool_def in tools_list:
|
||||
tool_name = tool_def.get("name", "")
|
||||
original_tool_name = tool_def.get("original_name", tool_name)
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
args_schema = None
|
||||
if tool_def.get("inputSchema"):
|
||||
args_schema = self._json_schema_to_pydantic(
|
||||
tool_name, tool_def["inputSchema"]
|
||||
)
|
||||
|
||||
tool_schema = {
|
||||
"description": tool_def.get("description", ""),
|
||||
"args_schema": args_schema,
|
||||
}
|
||||
|
||||
try:
|
||||
native_tool = MCPNativeTool(
|
||||
mcp_client=client,
|
||||
tool_name=tool_name,
|
||||
tool_schema=tool_schema,
|
||||
server_name=server_name,
|
||||
original_tool_name=original_tool_name,
|
||||
)
|
||||
tools.append(native_tool)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Failed to create native MCP tool: {e}")
|
||||
continue
|
||||
|
||||
return cast(list[BaseTool], tools), client
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
asyncio.run(client.disconnect())
|
||||
|
||||
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||
|
||||
@staticmethod
|
||||
def _build_mcp_config_from_dict(
|
||||
config_dict: dict[str, Any],
|
||||
) -> MCPServerConfig:
|
||||
"""Convert a config dict from crewai-oauth into an MCPServerConfig."""
|
||||
config_type = config_dict.get("type", "http")
|
||||
|
||||
if config_type == "sse":
|
||||
return MCPServerSSE(
|
||||
url=config_dict["url"],
|
||||
headers=config_dict.get("headers"),
|
||||
cache_tools_list=config_dict.get("cache_tools_list", False),
|
||||
)
|
||||
|
||||
return MCPServerHTTP(
|
||||
url=config_dict["url"],
|
||||
headers=config_dict.get("headers"),
|
||||
streamable=config_dict.get("streamable", True),
|
||||
cache_tools_list=config_dict.get("cache_tools_list", False),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_server_name(server_url: str) -> str:
|
||||
"""Extract clean server name from URL for tool prefixing."""
|
||||
parsed = urlparse(server_url)
|
||||
domain = parsed.netloc.replace(".", "_")
|
||||
path = parsed.path.replace("/", "_").strip("_")
|
||||
return f"{domain}_{path}" if path else domain
|
||||
|
||||
def _get_mcp_tool_schemas(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Get tool schemas from MCP server with caching."""
|
||||
server_url = server_params["url"]
|
||||
|
||||
cache_key = server_url
|
||||
current_time = time.time()
|
||||
|
||||
if cache_key in _mcp_schema_cache:
|
||||
cached_data, cache_time = _mcp_schema_cache[cache_key]
|
||||
if current_time - cache_time < _cache_ttl:
|
||||
self._logger.log(
|
||||
"debug", f"Using cached MCP tool schemas for {server_url}"
|
||||
)
|
||||
return cached_data # type: ignore[no-any-return]
|
||||
|
||||
try:
|
||||
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
|
||||
_mcp_schema_cache[cache_key] = (schemas, current_time)
|
||||
return schemas
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
"warning", f"Failed to get MCP tool schemas from {server_url}: {e}"
|
||||
)
|
||||
return {}
|
||||
|
||||
async def _get_mcp_tool_schemas_async(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Async implementation of MCP tool schema retrieval."""
|
||||
server_url = server_params["url"]
|
||||
return await self._retry_mcp_discovery(
|
||||
self._discover_mcp_tools_with_timeout, server_url
|
||||
)
|
||||
|
||||
async def _retry_mcp_discovery(
|
||||
self, operation_func: Any, server_url: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Retry MCP discovery with exponential backoff."""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(MCP_MAX_RETRIES):
|
||||
result, error, should_retry = await self._attempt_mcp_discovery(
|
||||
operation_func, server_url
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
if not should_retry:
|
||||
raise RuntimeError(error)
|
||||
|
||||
last_error = error
|
||||
if attempt < MCP_MAX_RETRIES - 1:
|
||||
wait_time = 2**attempt
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Failed to discover MCP tools after {MCP_MAX_RETRIES} attempts: {last_error}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _attempt_mcp_discovery(
|
||||
operation_func: Any, server_url: str
|
||||
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
|
||||
"""Attempt single MCP discovery; returns *(result, error_message, should_retry)*."""
|
||||
try:
|
||||
result = await operation_func(server_url)
|
||||
return result, "", False
|
||||
|
||||
except ImportError:
|
||||
return (
|
||||
None,
|
||||
"MCP library not available. Please install with: pip install mcp",
|
||||
False,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return (
|
||||
None,
|
||||
f"MCP discovery timed out after {MCP_DISCOVERY_TIMEOUT} seconds",
|
||||
True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
|
||||
if "authentication" in error_str or "unauthorized" in error_str:
|
||||
return None, f"Authentication failed for MCP server: {e!s}", False
|
||||
if "connection" in error_str or "network" in error_str:
|
||||
return None, f"Network connection failed: {e!s}", True
|
||||
if "json" in error_str or "parsing" in error_str:
|
||||
return None, f"Server response parsing error: {e!s}", True
|
||||
return None, f"MCP discovery error: {e!s}", False
|
||||
|
||||
async def _discover_mcp_tools_with_timeout(
|
||||
self, server_url: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Discover MCP tools with timeout wrapper."""
|
||||
return await asyncio.wait_for(
|
||||
self._discover_mcp_tools(server_url), timeout=MCP_DISCOVERY_TIMEOUT
|
||||
)
|
||||
|
||||
async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict[str, Any]]:
|
||||
"""Discover tools from an MCP server (HTTPS / streamable-HTTP path)."""
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
async with streamablehttp_client(server_url) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
await asyncio.wait_for(
|
||||
session.initialize(), timeout=MCP_CONNECTION_TIMEOUT
|
||||
)
|
||||
|
||||
tools_result = await asyncio.wait_for(
|
||||
session.list_tools(),
|
||||
timeout=MCP_DISCOVERY_TIMEOUT - MCP_CONNECTION_TIMEOUT,
|
||||
)
|
||||
|
||||
schemas = {}
|
||||
for tool in tools_result.tools:
|
||||
args_schema = None
|
||||
if hasattr(tool, "inputSchema") and tool.inputSchema:
|
||||
args_schema = self._json_schema_to_pydantic(
|
||||
sanitize_tool_name(tool.name), tool.inputSchema
|
||||
)
|
||||
|
||||
schemas[sanitize_tool_name(tool.name)] = {
|
||||
"description": getattr(tool, "description", ""),
|
||||
"args_schema": args_schema,
|
||||
}
|
||||
return schemas
|
||||
|
||||
@staticmethod
|
||||
def _json_schema_to_pydantic(tool_name: str, json_schema: dict[str, Any]) -> type:
|
||||
"""Convert JSON Schema to a Pydantic model for tool arguments."""
|
||||
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
|
||||
|
||||
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
|
||||
return create_model_from_schema(
|
||||
json_schema,
|
||||
model_name=model_name,
|
||||
enrich_descriptions=True,
|
||||
)
|
||||
@@ -1,6 +1,14 @@
|
||||
"""Memory module: unified Memory with LLM analysis and pluggable storage."""
|
||||
"""Memory module: unified Memory with LLM analysis and pluggable storage.
|
||||
|
||||
Heavy dependencies are lazily imported so that
|
||||
``import crewai`` does not initialise at runtime — critical for
|
||||
Celery pre-fork and similar deployment patterns.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from crewai.memory.encoding_flow import EncodingFlow
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.types import (
|
||||
MemoryMatch,
|
||||
@@ -10,7 +18,24 @@ from crewai.memory.types import (
|
||||
embed_text,
|
||||
embed_texts,
|
||||
)
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"Memory": ("crewai.memory.unified_memory", "Memory"),
|
||||
"EncodingFlow": ("crewai.memory.encoding_flow", "EncodingFlow"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Lazily import Memory / EncodingFlow to avoid pulling in lancedb at import time."""
|
||||
if name in _LAZY_IMPORTS:
|
||||
import importlib
|
||||
|
||||
module_path, attr = _LAZY_IMPORTS[name]
|
||||
mod = importlib.import_module(module_path)
|
||||
val = getattr(mod, attr)
|
||||
globals()[name] = val
|
||||
return val
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -99,8 +99,12 @@ class MemoryMatch(BaseModel):
|
||||
lines.append(f" categories: {', '.join(self.record.categories)}")
|
||||
if self.record.metadata:
|
||||
for key, value in self.record.metadata.items():
|
||||
if value is not None:
|
||||
lines.append(f" {key}: {value}")
|
||||
if value:
|
||||
if isinstance(value, list):
|
||||
rendered_value = ", ".join(str(item) for item in value)
|
||||
else:
|
||||
rendered_value = str(value)
|
||||
lines.append(f" {key}: {rendered_value}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@@ -307,7 +311,7 @@ def embed_text(embedder: Any, text: str) -> list[float]:
|
||||
return []
|
||||
first = result[0]
|
||||
if hasattr(first, "tolist"):
|
||||
return list(first.tolist())
|
||||
return first.tolist()
|
||||
if isinstance(first, list):
|
||||
return [float(x) for x in first]
|
||||
return list(first)
|
||||
|
||||
@@ -6,7 +6,7 @@ from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
@@ -21,7 +21,6 @@ from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.analyze import extract_memories_from_content
|
||||
from crewai.memory.recall_flow import RecallFlow
|
||||
from crewai.memory.storage.backend import StorageBackend
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
from crewai.memory.types import (
|
||||
MemoryConfig,
|
||||
MemoryMatch,
|
||||
@@ -30,20 +29,13 @@ from crewai.memory.types import (
|
||||
compute_composite_score,
|
||||
embed_text,
|
||||
)
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
def _default_embedder() -> OpenAIEmbeddingFunction:
|
||||
def _default_embedder() -> Any:
|
||||
"""Build default OpenAI embedder for memory."""
|
||||
spec: OpenAIProviderSpec = {"provider": "openai", "config": {}}
|
||||
return build_embedder(spec)
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
return build_embedder({"provider": "openai", "config": {}})
|
||||
|
||||
|
||||
class Memory:
|
||||
@@ -143,17 +135,13 @@ class Memory:
|
||||
self._llm_instance: BaseLLM | None = None if isinstance(llm, str) else llm
|
||||
self._embedder_config: Any = embedder
|
||||
self._embedder_instance: Any = (
|
||||
embedder
|
||||
if (embedder is not None and not isinstance(embedder, dict))
|
||||
else None
|
||||
embedder if (embedder is not None and not isinstance(embedder, dict)) else None
|
||||
)
|
||||
|
||||
# Storage is initialized eagerly (local, no API key needed).
|
||||
self._storage: StorageBackend
|
||||
if storage == "lancedb":
|
||||
self._storage = LanceDBStorage()
|
||||
elif isinstance(storage, str):
|
||||
self._storage = LanceDBStorage(path=storage)
|
||||
if isinstance(storage, str):
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
self._storage = LanceDBStorage() if storage == "lancedb" else LanceDBStorage(path=storage)
|
||||
else:
|
||||
self._storage = storage
|
||||
|
||||
@@ -176,17 +164,12 @@ class Memory:
|
||||
from crewai.llm import LLM
|
||||
|
||||
try:
|
||||
model_name = (
|
||||
self._llm_config
|
||||
if isinstance(self._llm_config, str)
|
||||
else str(self._llm_config)
|
||||
)
|
||||
self._llm_instance = LLM(model=model_name)
|
||||
self._llm_instance = LLM(model=self._llm_config)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Memory requires an LLM for analysis but initialization failed: {e}\n\n"
|
||||
"To fix this, do one of the following:\n"
|
||||
" - Set OPENAI_API_KEY for the default model (gpt-4o-mini)\n"
|
||||
' - Set OPENAI_API_KEY for the default model (gpt-4o-mini)\n'
|
||||
' - Pass a different model: Memory(llm="anthropic/claude-3-haiku-20240307")\n'
|
||||
' - Pass any LLM instance: Memory(llm=LLM(model="your-model"))\n'
|
||||
" - To skip LLM analysis, pass all fields explicitly to remember()\n"
|
||||
@@ -201,6 +184,8 @@ class Memory:
|
||||
if self._embedder_instance is None:
|
||||
try:
|
||||
if isinstance(self._embedder_config, dict):
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
self._embedder_instance = build_embedder(self._embedder_config)
|
||||
else:
|
||||
self._embedder_instance = _default_embedder()
|
||||
@@ -336,7 +321,7 @@ class Memory:
|
||||
source: str | None = None,
|
||||
private: bool = False,
|
||||
agent_role: str | None = None,
|
||||
) -> MemoryRecord | None:
|
||||
) -> MemoryRecord:
|
||||
"""Store a single item in memory (synchronous).
|
||||
|
||||
Routes through the same serialized save pool as ``remember_many``
|
||||
@@ -360,7 +345,7 @@ class Memory:
|
||||
Exception: On save failure (events emitted).
|
||||
"""
|
||||
if self._read_only:
|
||||
return None
|
||||
return None # type: ignore[return-value]
|
||||
_source_type = "unified_memory"
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
@@ -377,13 +362,7 @@ class Memory:
|
||||
# then immediately wait for the result.
|
||||
future = self._submit_save(
|
||||
self._encode_batch,
|
||||
[content],
|
||||
scope,
|
||||
categories,
|
||||
metadata,
|
||||
importance,
|
||||
source,
|
||||
private,
|
||||
[content], scope, categories, metadata, importance, source, private,
|
||||
)
|
||||
records = future.result()
|
||||
record = records[0] if records else None
|
||||
@@ -452,14 +431,8 @@ class Memory:
|
||||
|
||||
self._submit_save(
|
||||
self._background_encode_batch,
|
||||
contents,
|
||||
scope,
|
||||
categories,
|
||||
metadata,
|
||||
importance,
|
||||
source,
|
||||
private,
|
||||
agent_role,
|
||||
contents, scope, categories, metadata,
|
||||
importance, source, private, agent_role,
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -599,13 +572,14 @@ class Memory:
|
||||
# Privacy filter
|
||||
if not include_private:
|
||||
raw = [
|
||||
(r, s)
|
||||
for r, s in raw
|
||||
(r, s) for r, s in raw
|
||||
if not r.private or r.source == source
|
||||
]
|
||||
results = []
|
||||
for r, s in raw:
|
||||
composite, reasons = compute_composite_score(r, s, self._config)
|
||||
composite, reasons = compute_composite_score(
|
||||
r, s, self._config
|
||||
)
|
||||
results.append(
|
||||
MemoryMatch(
|
||||
record=r,
|
||||
@@ -771,9 +745,7 @@ class Memory:
|
||||
limit: Maximum number of records to return.
|
||||
offset: Number of records to skip (for pagination).
|
||||
"""
|
||||
return self._storage.list_records(
|
||||
scope_prefix=scope, limit=limit, offset=offset
|
||||
)
|
||||
return self._storage.list_records(scope_prefix=scope, limit=limit, offset=offset)
|
||||
|
||||
def info(self, path: str = "/") -> ScopeInfo:
|
||||
"""Return scope info for path."""
|
||||
@@ -815,7 +787,7 @@ class Memory:
|
||||
importance: float | None = None,
|
||||
source: str | None = None,
|
||||
private: bool = False,
|
||||
) -> MemoryRecord | None:
|
||||
) -> MemoryRecord:
|
||||
"""Async remember: delegates to sync for now."""
|
||||
return self.remember(
|
||||
content,
|
||||
|
||||
@@ -216,10 +216,6 @@ def build_embedder_from_dict(
|
||||
def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ...
|
||||
|
||||
|
||||
def build_embedder_from_dict(spec): # type: ignore[no-untyped-def]
|
||||
"""Build an embedding function instance from a dictionary specification.
|
||||
|
||||
@@ -345,10 +341,6 @@ def build_embedder(spec: Text2VecProviderSpec) -> Text2VecEmbeddingFunction: ...
|
||||
def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ...
|
||||
|
||||
|
||||
def build_embedder(spec): # type: ignore[no-untyped-def]
|
||||
"""Build an embedding function from either a provider spec or a provider instance.
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from pydantic import (
|
||||
BaseModel as PydanticBaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
ValidationError,
|
||||
create_model,
|
||||
field_validator,
|
||||
)
|
||||
@@ -162,7 +163,7 @@ class BaseTool(BaseModel, ABC):
|
||||
Raises:
|
||||
ValueError: If validation against args_schema fails.
|
||||
"""
|
||||
if self.args_schema is not None and self.args_schema.model_fields:
|
||||
if kwargs and self.args_schema is not None and self.args_schema.model_fields:
|
||||
try:
|
||||
validated = self.args_schema.model_validate(kwargs)
|
||||
return validated.model_dump()
|
||||
@@ -177,8 +178,7 @@ class BaseTool(BaseModel, ABC):
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs)
|
||||
kwargs = self._validate_kwargs(kwargs)
|
||||
|
||||
result = self._run(*args, **kwargs)
|
||||
|
||||
@@ -203,8 +203,7 @@ class BaseTool(BaseModel, ABC):
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs)
|
||||
kwargs = self._validate_kwargs(kwargs)
|
||||
result = await self._arun(*args, **kwargs)
|
||||
self.current_usage_count += 1
|
||||
return result
|
||||
@@ -357,8 +356,7 @@ class Tool(BaseTool, Generic[P, R]):
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs) # type: ignore[assignment]
|
||||
kwargs = self._validate_kwargs(kwargs)
|
||||
|
||||
result = self.func(*args, **kwargs)
|
||||
|
||||
@@ -390,8 +388,7 @@ class Tool(BaseTool, Generic[P, R]):
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs) # type: ignore[assignment]
|
||||
kwargs = self._validate_kwargs(kwargs)
|
||||
result = await self._arun(*args, **kwargs)
|
||||
self.current_usage_count += 1
|
||||
return result
|
||||
|
||||
@@ -27,16 +27,14 @@ class MCPNativeTool(BaseTool):
|
||||
tool_name: str,
|
||||
tool_schema: dict[str, Any],
|
||||
server_name: str,
|
||||
original_tool_name: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize native MCP tool.
|
||||
|
||||
Args:
|
||||
mcp_client: MCPClient instance with active session.
|
||||
tool_name: Name of the tool (may be prefixed).
|
||||
tool_name: Original name of the tool on the MCP server.
|
||||
tool_schema: Schema information for the tool.
|
||||
server_name: Name of the MCP server for prefixing.
|
||||
original_tool_name: Original name of the tool on the MCP server.
|
||||
"""
|
||||
# Create tool name with server prefix to avoid conflicts
|
||||
prefixed_name = f"{server_name}_{tool_name}"
|
||||
@@ -59,7 +57,7 @@ class MCPNativeTool(BaseTool):
|
||||
|
||||
# Set instance attributes after super().__init__
|
||||
self._mcp_client = mcp_client
|
||||
self._original_tool_name = original_tool_name or tool_name
|
||||
self._original_tool_name = tool_name
|
||||
self._server_name = server_name
|
||||
# self._logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class RecallMemoryTool(BaseTool):
|
||||
|
||||
if not all_lines:
|
||||
return "No relevant memories found."
|
||||
return "Found memories:\n" + "\n".join(all_lines)
|
||||
return "Found memories:\n" + "\n\n".join(all_lines)
|
||||
|
||||
|
||||
class RememberSchema(BaseModel):
|
||||
|
||||
@@ -168,9 +168,7 @@ def convert_tools_to_openai_schema(
|
||||
parameters: dict[str, Any] = {}
|
||||
if hasattr(tool, "args_schema") and tool.args_schema is not None:
|
||||
try:
|
||||
schema_output = generate_model_description(
|
||||
tool.args_schema, strip_null_types=False
|
||||
)
|
||||
schema_output = generate_model_description(tool.args_schema)
|
||||
parameters = schema_output.get("json_schema", {}).get("schema", {})
|
||||
# Remove title and description from schema root as they're redundant
|
||||
parameters.pop("title", None)
|
||||
|
||||
@@ -417,11 +417,7 @@ def strip_null_from_types(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
return schema
|
||||
|
||||
|
||||
def generate_model_description(
|
||||
model: type[BaseModel],
|
||||
*,
|
||||
strip_null_types: bool = True,
|
||||
) -> ModelDescription:
|
||||
def generate_model_description(model: type[BaseModel]) -> ModelDescription:
|
||||
"""Generate JSON schema description of a Pydantic model.
|
||||
|
||||
This function takes a Pydantic model class and returns its JSON schema,
|
||||
@@ -430,9 +426,6 @@ def generate_model_description(
|
||||
|
||||
Args:
|
||||
model: A Pydantic model class.
|
||||
strip_null_types: When ``True`` (default), remove ``null`` from
|
||||
``anyOf`` / ``type`` arrays. Set to ``False`` to allow sending ``null`` for
|
||||
optional fields.
|
||||
|
||||
Returns:
|
||||
A ModelDescription with JSON schema representation of the model.
|
||||
@@ -449,9 +442,7 @@ def generate_model_description(
|
||||
json_schema = fix_discriminator_mappings(json_schema)
|
||||
json_schema = convert_oneof_to_anyof(json_schema)
|
||||
json_schema = ensure_all_properties_required(json_schema)
|
||||
|
||||
if strip_null_types:
|
||||
json_schema = strip_null_from_types(json_schema)
|
||||
json_schema = strip_null_from_types(json_schema)
|
||||
|
||||
return {
|
||||
"type": "json_schema",
|
||||
@@ -491,66 +482,10 @@ FORMAT_TYPE_MAP: dict[str, type[Any]] = {
|
||||
}
|
||||
|
||||
|
||||
def build_rich_field_description(prop_schema: dict[str, Any]) -> str:
|
||||
"""Build a comprehensive field description including constraints.
|
||||
|
||||
Embeds format, enum, pattern, min/max, and example constraints into the
|
||||
description text so that LLMs can understand tool parameter requirements
|
||||
without inspecting the raw JSON Schema.
|
||||
|
||||
Args:
|
||||
prop_schema: Property schema with description and constraints.
|
||||
|
||||
Returns:
|
||||
Enhanced description with format, enum, and other constraints.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
|
||||
description = prop_schema.get("description", "")
|
||||
if description:
|
||||
parts.append(description)
|
||||
|
||||
format_type = prop_schema.get("format")
|
||||
if format_type:
|
||||
parts.append(f"Format: {format_type}")
|
||||
|
||||
enum_values = prop_schema.get("enum")
|
||||
if enum_values:
|
||||
enum_str = ", ".join(repr(v) for v in enum_values)
|
||||
parts.append(f"Allowed values: [{enum_str}]")
|
||||
|
||||
pattern = prop_schema.get("pattern")
|
||||
if pattern:
|
||||
parts.append(f"Pattern: {pattern}")
|
||||
|
||||
minimum = prop_schema.get("minimum")
|
||||
maximum = prop_schema.get("maximum")
|
||||
if minimum is not None:
|
||||
parts.append(f"Minimum: {minimum}")
|
||||
if maximum is not None:
|
||||
parts.append(f"Maximum: {maximum}")
|
||||
|
||||
min_length = prop_schema.get("minLength")
|
||||
max_length = prop_schema.get("maxLength")
|
||||
if min_length is not None:
|
||||
parts.append(f"Min length: {min_length}")
|
||||
if max_length is not None:
|
||||
parts.append(f"Max length: {max_length}")
|
||||
|
||||
examples = prop_schema.get("examples")
|
||||
if examples:
|
||||
examples_str = ", ".join(repr(e) for e in examples[:3])
|
||||
parts.append(f"Examples: {examples_str}")
|
||||
|
||||
return ". ".join(parts) if parts else ""
|
||||
|
||||
|
||||
def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
json_schema: dict[str, Any],
|
||||
*,
|
||||
root_schema: dict[str, Any] | None = None,
|
||||
model_name: str | None = None,
|
||||
enrich_descriptions: bool = False,
|
||||
__config__: ConfigDict | None = None,
|
||||
__base__: type[BaseModel] | None = None,
|
||||
__module__: str = __name__,
|
||||
@@ -568,13 +503,6 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
json_schema: A dictionary representing the JSON schema.
|
||||
root_schema: The root schema containing $defs. If not provided, the
|
||||
current schema is treated as the root schema.
|
||||
model_name: Override for the model name. If not provided, the schema
|
||||
``title`` field is used, falling back to ``"DynamicModel"``.
|
||||
enrich_descriptions: When True, augment field descriptions with
|
||||
constraint info (format, enum, pattern, min/max, examples) via
|
||||
:func:`build_rich_field_description`. Useful for LLM-facing tool
|
||||
schemas where constraints in the description help the model
|
||||
understand parameter requirements.
|
||||
__config__: Pydantic configuration for the generated model.
|
||||
__base__: Base class for the generated model. Defaults to BaseModel.
|
||||
__module__: Module name for the generated model class.
|
||||
@@ -611,14 +539,10 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
if "title" not in json_schema and "title" in (root_schema or {}):
|
||||
json_schema["title"] = (root_schema or {}).get("title")
|
||||
|
||||
effective_name = model_name or json_schema.get("title") or "DynamicModel"
|
||||
model_name = json_schema.get("title") or "DynamicModel"
|
||||
field_definitions = {
|
||||
name: _json_schema_to_pydantic_field(
|
||||
name,
|
||||
prop,
|
||||
json_schema.get("required", []),
|
||||
effective_root,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
name, prop, json_schema.get("required", []), effective_root
|
||||
)
|
||||
for name, prop in (json_schema.get("properties", {}) or {}).items()
|
||||
}
|
||||
@@ -626,7 +550,7 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
effective_config = __config__ or ConfigDict(extra="forbid")
|
||||
|
||||
return create_model_base(
|
||||
effective_name,
|
||||
model_name,
|
||||
__config__=effective_config,
|
||||
__base__=__base__,
|
||||
__module__=__module__,
|
||||
@@ -641,8 +565,6 @@ def _json_schema_to_pydantic_field(
|
||||
json_schema: dict[str, Any],
|
||||
required: list[str],
|
||||
root_schema: dict[str, Any],
|
||||
*,
|
||||
enrich_descriptions: bool = False,
|
||||
) -> Any:
|
||||
"""Convert a JSON schema property to a Pydantic field definition.
|
||||
|
||||
@@ -651,29 +573,20 @@ def _json_schema_to_pydantic_field(
|
||||
json_schema: The JSON schema for this field.
|
||||
required: List of required field names.
|
||||
root_schema: The root schema for resolving $ref.
|
||||
enrich_descriptions: When True, embed constraints in the description.
|
||||
|
||||
Returns:
|
||||
A tuple of (type, Field) for use with create_model.
|
||||
"""
|
||||
type_ = _json_schema_to_pydantic_type(
|
||||
json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions
|
||||
)
|
||||
type_ = _json_schema_to_pydantic_type(json_schema, root_schema, name_=name.title())
|
||||
description = json_schema.get("description")
|
||||
examples = json_schema.get("examples")
|
||||
is_required = name in required
|
||||
|
||||
field_params: dict[str, Any] = {}
|
||||
schema_extra: dict[str, Any] = {}
|
||||
|
||||
if enrich_descriptions:
|
||||
rich_desc = build_rich_field_description(json_schema)
|
||||
if rich_desc:
|
||||
field_params["description"] = rich_desc
|
||||
else:
|
||||
description = json_schema.get("description")
|
||||
if description:
|
||||
field_params["description"] = description
|
||||
|
||||
examples = json_schema.get("examples")
|
||||
if description:
|
||||
field_params["description"] = description
|
||||
if examples:
|
||||
schema_extra["examples"] = examples
|
||||
|
||||
@@ -789,7 +702,6 @@ def _json_schema_to_pydantic_type(
|
||||
root_schema: dict[str, Any],
|
||||
*,
|
||||
name_: str | None = None,
|
||||
enrich_descriptions: bool = False,
|
||||
) -> Any:
|
||||
"""Convert a JSON schema to a Python/Pydantic type.
|
||||
|
||||
@@ -797,7 +709,6 @@ def _json_schema_to_pydantic_type(
|
||||
json_schema: The JSON schema to convert.
|
||||
root_schema: The root schema for resolving $ref.
|
||||
name_: Optional name for nested models.
|
||||
enrich_descriptions: Propagated to nested model creation.
|
||||
|
||||
Returns:
|
||||
A Python type corresponding to the JSON schema.
|
||||
@@ -805,9 +716,7 @@ def _json_schema_to_pydantic_type(
|
||||
ref = json_schema.get("$ref")
|
||||
if ref:
|
||||
ref_schema = _resolve_ref(ref, root_schema)
|
||||
return _json_schema_to_pydantic_type(
|
||||
ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions
|
||||
)
|
||||
return _json_schema_to_pydantic_type(ref_schema, root_schema, name_=name_)
|
||||
|
||||
enum_values = json_schema.get("enum")
|
||||
if enum_values:
|
||||
@@ -822,10 +731,7 @@ def _json_schema_to_pydantic_type(
|
||||
if any_of_schemas:
|
||||
any_of_types = [
|
||||
_json_schema_to_pydantic_type(
|
||||
schema,
|
||||
root_schema,
|
||||
name_=f"{name_ or 'Union'}Option{i}",
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
schema, root_schema, name_=f"{name_ or 'Union'}Option{i}"
|
||||
)
|
||||
for i, schema in enumerate(any_of_schemas)
|
||||
]
|
||||
@@ -835,14 +741,10 @@ def _json_schema_to_pydantic_type(
|
||||
if all_of_schemas:
|
||||
if len(all_of_schemas) == 1:
|
||||
return _json_schema_to_pydantic_type(
|
||||
all_of_schemas[0], root_schema, name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
all_of_schemas[0], root_schema, name_=name_
|
||||
)
|
||||
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
|
||||
return _json_schema_to_pydantic_type(
|
||||
merged, root_schema, name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
return _json_schema_to_pydantic_type(merged, root_schema, name_=name_)
|
||||
|
||||
type_ = json_schema.get("type")
|
||||
|
||||
@@ -858,8 +760,7 @@ def _json_schema_to_pydantic_type(
|
||||
items_schema = json_schema.get("items")
|
||||
if items_schema:
|
||||
item_type = _json_schema_to_pydantic_type(
|
||||
items_schema, root_schema, name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
items_schema, root_schema, name_=name_
|
||||
)
|
||||
return list[item_type] # type: ignore[valid-type]
|
||||
return list
|
||||
@@ -869,10 +770,7 @@ def _json_schema_to_pydantic_type(
|
||||
json_schema_ = json_schema.copy()
|
||||
if json_schema_.get("title") is None:
|
||||
json_schema_["title"] = name_ or "DynamicModel"
|
||||
return create_model_from_schema(
|
||||
json_schema_, root_schema=root_schema,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
return create_model_from_schema(json_schema_, root_schema=root_schema)
|
||||
return dict
|
||||
if type_ == "null":
|
||||
return None
|
||||
|
||||
@@ -659,7 +659,7 @@ def test_agent_kickoff_with_platform_tools(mock_get, mock_post):
|
||||
|
||||
|
||||
@patch.dict("os.environ", {"EXA_API_KEY": "test_exa_key"})
|
||||
@patch("crewai.agent.Agent.get_mcp_tools")
|
||||
@patch("crewai.agent.Agent._get_external_mcp_tools")
|
||||
@pytest.mark.vcr()
|
||||
def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
|
||||
"""Test that Agent.kickoff() properly integrates MCP tools with LiteAgent"""
|
||||
@@ -691,7 +691,7 @@ def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
|
||||
assert result.raw is not None
|
||||
|
||||
# Verify MCP tools were retrieved
|
||||
mock_get_mcp_tools.assert_called_once_with(["https://mcp.exa.ai/mcp?api_key=test_exa_key&profile=research"])
|
||||
mock_get_mcp_tools.assert_called_once_with("https://mcp.exa.ai/mcp?api_key=test_exa_key&profile=research")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -1136,7 +1136,6 @@ def test_lite_agent_memory_instance_recall_and_save_called():
|
||||
successful_requests=1,
|
||||
)
|
||||
mock_memory = Mock()
|
||||
mock_memory._read_only = False
|
||||
mock_memory.recall.return_value = []
|
||||
mock_memory.extract_memories.return_value = ["Fact one.", "Fact two."]
|
||||
|
||||
|
||||
@@ -1,373 +0,0 @@
|
||||
"""Tests for AMP MCP config fetching and tool resolution."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.mcp.config import MCPServerHTTP, MCPServerSSE
|
||||
from crewai.mcp.tool_resolver import MCPToolResolver
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
return Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resolver(agent):
|
||||
return MCPToolResolver(agent=agent, logger=agent._logger)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_definitions():
|
||||
return [
|
||||
{
|
||||
"name": "search",
|
||||
"description": "Search tool",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "create_page",
|
||||
"description": "Create a page",
|
||||
"inputSchema": {},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class TestBuildMCPConfigFromDict:
|
||||
def test_builds_http_config(self):
|
||||
config_dict = {
|
||||
"type": "http",
|
||||
"url": "https://mcp.example.com/api",
|
||||
"headers": {"Authorization": "Bearer token123"},
|
||||
"streamable": True,
|
||||
"cache_tools_list": False,
|
||||
}
|
||||
|
||||
result = MCPToolResolver._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
assert isinstance(result, MCPServerHTTP)
|
||||
assert result.url == "https://mcp.example.com/api"
|
||||
assert result.headers == {"Authorization": "Bearer token123"}
|
||||
assert result.streamable is True
|
||||
assert result.cache_tools_list is False
|
||||
|
||||
def test_builds_sse_config(self):
|
||||
config_dict = {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.example.com/sse",
|
||||
"headers": {"Authorization": "Bearer token123"},
|
||||
"cache_tools_list": True,
|
||||
}
|
||||
|
||||
result = MCPToolResolver._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
assert isinstance(result, MCPServerSSE)
|
||||
assert result.url == "https://mcp.example.com/sse"
|
||||
assert result.headers == {"Authorization": "Bearer token123"}
|
||||
assert result.cache_tools_list is True
|
||||
|
||||
def test_defaults_to_http(self):
|
||||
config_dict = {
|
||||
"url": "https://mcp.example.com/api",
|
||||
}
|
||||
|
||||
result = MCPToolResolver._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
assert isinstance(result, MCPServerHTTP)
|
||||
assert result.streamable is True
|
||||
|
||||
def test_http_defaults(self):
|
||||
config_dict = {
|
||||
"type": "http",
|
||||
"url": "https://mcp.example.com/api",
|
||||
}
|
||||
|
||||
result = MCPToolResolver._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
assert result.headers is None
|
||||
assert result.streamable is True
|
||||
assert result.cache_tools_list is False
|
||||
|
||||
|
||||
class TestFetchAmpMCPConfigs:
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_fetches_configs_successfully(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"configs": {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
"headers": {"Authorization": "Bearer notion-token"},
|
||||
},
|
||||
"github": {
|
||||
"type": "http",
|
||||
"url": "https://mcp.github.com/api",
|
||||
"headers": {"Authorization": "Bearer gh-token"},
|
||||
},
|
||||
},
|
||||
}
|
||||
mock_plus_api = MagicMock()
|
||||
mock_plus_api.get_mcp_configs.return_value = mock_response
|
||||
mock_plus_api_class.return_value = mock_plus_api
|
||||
|
||||
result = resolver._fetch_amp_mcp_configs(["notion", "github"])
|
||||
|
||||
assert "notion" in result
|
||||
assert "github" in result
|
||||
assert result["notion"]["url"] == "https://mcp.notion.so/sse"
|
||||
mock_plus_api_class.assert_called_once_with(api_key="test-api-key")
|
||||
mock_plus_api.get_mcp_configs.assert_called_once_with(["notion", "github"])
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_omits_missing_slugs(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"configs": {"notion": {"type": "sse", "url": "https://mcp.notion.so/sse"}},
|
||||
}
|
||||
mock_plus_api = MagicMock()
|
||||
mock_plus_api.get_mcp_configs.return_value = mock_response
|
||||
mock_plus_api_class.return_value = mock_plus_api
|
||||
|
||||
result = resolver._fetch_amp_mcp_configs(["notion", "missing-server"])
|
||||
|
||||
assert "notion" in result
|
||||
assert "missing-server" not in result
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_returns_empty_on_http_error(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_plus_api = MagicMock()
|
||||
mock_plus_api.get_mcp_configs.return_value = mock_response
|
||||
mock_plus_api_class.return_value = mock_plus_api
|
||||
|
||||
result = resolver._fetch_amp_mcp_configs(["notion"])
|
||||
|
||||
assert result == {}
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_returns_empty_on_network_error(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
import httpx
|
||||
|
||||
mock_plus_api = MagicMock()
|
||||
mock_plus_api.get_mcp_configs.side_effect = httpx.ConnectError("Connection refused")
|
||||
mock_plus_api_class.return_value = mock_plus_api
|
||||
|
||||
result = resolver._fetch_amp_mcp_configs(["notion"])
|
||||
|
||||
assert result == {}
|
||||
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", side_effect=Exception("No token"))
|
||||
def test_returns_empty_when_no_token(self, mock_get_token, resolver):
|
||||
result = resolver._fetch_amp_mcp_configs(["notion"])
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestParseAmpRef:
|
||||
def test_bare_slug(self):
|
||||
slug, tool = MCPToolResolver._parse_amp_ref("notion")
|
||||
assert slug == "notion"
|
||||
assert tool is None
|
||||
|
||||
def test_bare_slug_with_tool(self):
|
||||
slug, tool = MCPToolResolver._parse_amp_ref("notion#search")
|
||||
assert slug == "notion"
|
||||
assert tool == "search"
|
||||
|
||||
def test_bare_slug_with_empty_tool(self):
|
||||
slug, tool = MCPToolResolver._parse_amp_ref("notion#")
|
||||
assert slug == "notion"
|
||||
assert tool is None
|
||||
|
||||
def test_legacy_prefix_slug(self):
|
||||
slug, tool = MCPToolResolver._parse_amp_ref("crewai-amp:notion")
|
||||
assert slug == "notion"
|
||||
assert tool is None
|
||||
|
||||
def test_legacy_prefix_with_tool(self):
|
||||
slug, tool = MCPToolResolver._parse_amp_ref("crewai-amp:notion#search")
|
||||
assert slug == "notion"
|
||||
assert tool == "search"
|
||||
|
||||
|
||||
class TestGetMCPToolsAmpIntegration:
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_single_request_for_multiple_amp_refs(
|
||||
self, mock_fetch, mock_client_class, agent, mock_tool_definitions
|
||||
):
|
||||
mock_fetch.return_value = {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
},
|
||||
"github": {
|
||||
"type": "http",
|
||||
"url": "https://mcp.github.com/api",
|
||||
"headers": {"Authorization": "Bearer gh-token"},
|
||||
"streamable": True,
|
||||
},
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
tools = agent.get_mcp_tools(["notion", "github"])
|
||||
|
||||
mock_fetch.assert_called_once_with(["notion", "github"])
|
||||
assert len(tools) == 4 # 2 tools per server
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_tool_filter_with_hash_syntax(
|
||||
self, mock_fetch, mock_client_class, agent, mock_tool_definitions
|
||||
):
|
||||
mock_fetch.return_value = {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
},
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
tools = agent.get_mcp_tools(["notion#search"])
|
||||
|
||||
mock_fetch.assert_called_once_with(["notion"])
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "mcp_notion_so_sse_search"
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_deduplicates_slugs(
|
||||
self, mock_fetch, mock_client_class, agent, mock_tool_definitions
|
||||
):
|
||||
mock_fetch.return_value = {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
},
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
tools = agent.get_mcp_tools(["notion#search", "notion#create_page"])
|
||||
|
||||
mock_fetch.assert_called_once_with(["notion"])
|
||||
assert len(tools) == 2
|
||||
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_skips_missing_configs_gracefully(self, mock_fetch, agent):
|
||||
mock_fetch.return_value = {}
|
||||
|
||||
tools = agent.get_mcp_tools(["missing-server"])
|
||||
|
||||
assert tools == []
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_legacy_crewai_amp_prefix_still_works(
|
||||
self, mock_fetch, mock_client_class, agent, mock_tool_definitions
|
||||
):
|
||||
mock_fetch.return_value = {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
},
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
tools = agent.get_mcp_tools(["crewai-amp:notion"])
|
||||
|
||||
mock_fetch.assert_called_once_with(["notion"])
|
||||
assert len(tools) == 2
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
@patch.object(MCPToolResolver, "_resolve_external")
|
||||
def test_non_amp_items_unaffected(
|
||||
self,
|
||||
mock_external,
|
||||
mock_fetch,
|
||||
mock_client_class,
|
||||
agent,
|
||||
mock_tool_definitions,
|
||||
):
|
||||
mock_fetch.return_value = {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
},
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_external_tool = MagicMock(spec=BaseTool)
|
||||
mock_external.return_value = [mock_external_tool]
|
||||
|
||||
http_config = MCPServerHTTP(
|
||||
url="https://other.mcp.com/api",
|
||||
headers={"Authorization": "Bearer other"},
|
||||
)
|
||||
|
||||
tools = agent.get_mcp_tools(
|
||||
[
|
||||
"notion",
|
||||
"https://external.mcp.com/api",
|
||||
http_config,
|
||||
]
|
||||
)
|
||||
|
||||
mock_fetch.assert_called_once_with(["notion"])
|
||||
mock_external.assert_called_once_with("https://external.mcp.com/api")
|
||||
# 2 from notion + 1 from external + 2 from http_config
|
||||
assert len(tools) == 5
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.agent.core import Agent
|
||||
@@ -46,7 +46,7 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions):
|
||||
)
|
||||
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False # Will trigger connect
|
||||
@@ -82,7 +82,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions):
|
||||
mcps=[http_config],
|
||||
)
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False # Will trigger connect
|
||||
@@ -117,7 +117,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions):
|
||||
mcps=[sse_config],
|
||||
)
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
@@ -141,7 +141,7 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
|
||||
"""Test MCPNativeTool execution in synchronous context (normal crew execution)."""
|
||||
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
@@ -173,7 +173,7 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
|
||||
"""Test MCPNativeTool execution in async context (e.g., from a Flow)."""
|
||||
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
|
||||
@@ -319,7 +319,6 @@ def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:
|
||||
from crewai.agents.parser import AgentFinish
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory._read_only = False
|
||||
mock_memory.extract_memories.return_value = ["Fact A.", "Fact B."]
|
||||
|
||||
mock_agent = MagicMock()
|
||||
@@ -360,7 +359,6 @@ def test_executor_save_to_memory_skips_delegation_output() -> None:
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory._read_only = False
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.memory = mock_memory
|
||||
mock_agent._logger = MagicMock()
|
||||
|
||||
@@ -268,13 +268,6 @@ class TestBaseToolRunValidation:
|
||||
result = t.run(code="console.log('hi')", language="javascript")
|
||||
assert result == "Executed javascript: console.log('hi')"
|
||||
|
||||
def test_run_with_no_args_raises_validation_error(self) -> None:
|
||||
"""Calling run() with no arguments should raise a clear ValueError,
|
||||
not a cryptic TypeError about missing positional arguments (GH-4611)."""
|
||||
t = CodeExecutorTool()
|
||||
with pytest.raises(ValueError, match="validation failed"):
|
||||
t.run()
|
||||
|
||||
def test_run_with_missing_required_kwarg_raises(self) -> None:
|
||||
"""Missing required kwargs should raise ValueError from schema validation."""
|
||||
t = CodeExecutorTool()
|
||||
@@ -385,13 +378,6 @@ class TestBaseToolArunValidation:
|
||||
result = await t.arun(code="print('hello')")
|
||||
assert result == "Async executed python: print('hello')"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_with_no_args_raises_validation_error(self) -> None:
|
||||
"""Calling arun() with no arguments should raise a clear ValueError (GH-4611)."""
|
||||
t = AsyncCodeExecutorTool()
|
||||
with pytest.raises(ValueError, match="validation failed"):
|
||||
await t.arun()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_with_missing_required_kwarg_raises(self) -> None:
|
||||
"""Missing required kwargs should raise ValueError in arun."""
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -235,79 +235,6 @@ def _make_mock_i18n() -> MagicMock:
|
||||
}.get(key, "")
|
||||
return mock_i18n
|
||||
|
||||
class MCPStyleInput(BaseModel):
|
||||
"""Input schema mimicking an MCP tool with optional fields."""
|
||||
|
||||
query: str = Field(description="Search query")
|
||||
filter_type: Optional[Literal["internal", "user"]] = Field(
|
||||
default=None, description="Filter type"
|
||||
)
|
||||
page_id: Optional[str] = Field(
|
||||
default=None, description="Page UUID"
|
||||
)
|
||||
|
||||
|
||||
class MCPStyleTool(BaseTool):
|
||||
"""A tool mimicking MCP tool schemas with optional fields."""
|
||||
|
||||
name: str = "mcp_search"
|
||||
description: str = "Search with optional filters"
|
||||
args_schema: type[BaseModel] = MCPStyleInput
|
||||
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
return "result"
|
||||
|
||||
|
||||
class TestOptionalFieldsPreserveNull:
|
||||
"""Tests that optional tool fields preserve null in the schema."""
|
||||
|
||||
def test_optional_string_allows_null(self) -> None:
|
||||
"""Optional[str] fields should include null in the schema so the LLM
|
||||
can send null instead of being forced to guess a value."""
|
||||
tools = [MCPStyleTool()]
|
||||
schemas, _ = convert_tools_to_openai_schema(tools)
|
||||
|
||||
params = schemas[0]["function"]["parameters"]
|
||||
page_id_prop = params["properties"]["page_id"]
|
||||
|
||||
assert "anyOf" in page_id_prop
|
||||
type_options = [opt.get("type") for opt in page_id_prop["anyOf"]]
|
||||
assert "string" in type_options
|
||||
assert "null" in type_options
|
||||
|
||||
def test_optional_literal_allows_null(self) -> None:
|
||||
"""Optional[Literal[...]] fields should include null."""
|
||||
tools = [MCPStyleTool()]
|
||||
schemas, _ = convert_tools_to_openai_schema(tools)
|
||||
|
||||
params = schemas[0]["function"]["parameters"]
|
||||
filter_prop = params["properties"]["filter_type"]
|
||||
|
||||
assert "anyOf" in filter_prop
|
||||
has_null = any(opt.get("type") == "null" for opt in filter_prop["anyOf"])
|
||||
assert has_null
|
||||
|
||||
def test_required_field_stays_non_null(self) -> None:
|
||||
"""Required fields without Optional should NOT have null."""
|
||||
tools = [MCPStyleTool()]
|
||||
schemas, _ = convert_tools_to_openai_schema(tools)
|
||||
|
||||
params = schemas[0]["function"]["parameters"]
|
||||
query_prop = params["properties"]["query"]
|
||||
|
||||
assert query_prop.get("type") == "string"
|
||||
assert "anyOf" not in query_prop
|
||||
|
||||
def test_all_fields_in_required_for_strict_mode(self) -> None:
|
||||
"""All fields (including optional) must be in required for strict mode."""
|
||||
tools = [MCPStyleTool()]
|
||||
schemas, _ = convert_tools_to_openai_schema(tools)
|
||||
|
||||
params = schemas[0]["function"]["parameters"]
|
||||
assert "query" in params["required"]
|
||||
assert "filter_type" in params["required"]
|
||||
assert "page_id" in params["required"]
|
||||
|
||||
|
||||
class TestSummarizeMessages:
|
||||
"""Tests for summarize_messages function."""
|
||||
|
||||
@@ -1,884 +0,0 @@
|
||||
"""Tests for pydantic_schema_utils module.
|
||||
|
||||
Covers:
|
||||
- create_model_from_schema: type mapping, required/optional, enums, formats,
|
||||
nested objects, arrays, unions, allOf, $ref, model_name, enrich_descriptions
|
||||
- Schema transformation helpers: resolve_refs, force_additional_properties_false,
|
||||
strip_unsupported_formats, ensure_type_in_schemas, convert_oneof_to_anyof,
|
||||
ensure_all_properties_required, strip_null_from_types, build_rich_field_description
|
||||
- End-to-end MCP tool schema conversion
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
build_rich_field_description,
|
||||
convert_oneof_to_anyof,
|
||||
create_model_from_schema,
|
||||
ensure_all_properties_required,
|
||||
ensure_type_in_schemas,
|
||||
force_additional_properties_false,
|
||||
resolve_refs,
|
||||
strip_null_from_types,
|
||||
strip_unsupported_formats,
|
||||
)
|
||||
|
||||
|
||||
class TestSimpleTypes:
|
||||
def test_string_field(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(name="Alice")
|
||||
assert obj.name == "Alice"
|
||||
|
||||
def test_integer_field(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
"required": ["count"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(count=42)
|
||||
assert obj.count == 42
|
||||
|
||||
def test_number_field(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"score": {"type": "number"}},
|
||||
"required": ["score"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(score=3.14)
|
||||
assert obj.score == pytest.approx(3.14)
|
||||
|
||||
def test_boolean_field(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"active": {"type": "boolean"}},
|
||||
"required": ["active"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
assert Model(active=True).active is True
|
||||
|
||||
def test_null_field(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"value": {"type": "null"}},
|
||||
"required": ["value"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(value=None)
|
||||
assert obj.value is None
|
||||
|
||||
|
||||
class TestRequiredOptional:
|
||||
def test_required_field_has_no_default(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
with pytest.raises(Exception):
|
||||
Model()
|
||||
|
||||
def test_optional_field_defaults_to_none(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": [],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model()
|
||||
assert obj.name is None
|
||||
|
||||
def test_mixed_required_optional(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "integer"},
|
||||
"label": {"type": "string"},
|
||||
},
|
||||
"required": ["id"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(id=1)
|
||||
assert obj.id == 1
|
||||
assert obj.label is None
|
||||
|
||||
|
||||
class TestEnumLiteral:
|
||||
def test_string_enum(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"color": {"type": "string", "enum": ["red", "green", "blue"]},
|
||||
},
|
||||
"required": ["color"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(color="red")
|
||||
assert obj.color == "red"
|
||||
|
||||
def test_string_enum_rejects_invalid(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"color": {"type": "string", "enum": ["red", "green", "blue"]},
|
||||
},
|
||||
"required": ["color"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
with pytest.raises(Exception):
|
||||
Model(color="yellow")
|
||||
|
||||
def test_const_value(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {"const": "fixed"},
|
||||
},
|
||||
"required": ["kind"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(kind="fixed")
|
||||
assert obj.kind == "fixed"
|
||||
|
||||
|
||||
class TestFormatMapping:
|
||||
def test_date_format(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"birthday": {"type": "string", "format": "date"},
|
||||
},
|
||||
"required": ["birthday"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(birthday=datetime.date(2000, 1, 15))
|
||||
assert obj.birthday == datetime.date(2000, 1, 15)
|
||||
|
||||
def test_datetime_format(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {"type": "string", "format": "date-time"},
|
||||
},
|
||||
"required": ["created_at"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
dt = datetime.datetime(2025, 6, 1, 12, 0, 0)
|
||||
obj = Model(created_at=dt)
|
||||
assert obj.created_at == dt
|
||||
|
||||
def test_time_format(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"alarm": {"type": "string", "format": "time"},
|
||||
},
|
||||
"required": ["alarm"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
t = datetime.time(8, 30)
|
||||
obj = Model(alarm=t)
|
||||
assert obj.alarm == t
|
||||
|
||||
|
||||
class TestNestedObjects:
|
||||
def test_nested_object_creates_model(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"street": {"type": "string"},
|
||||
"city": {"type": "string"},
|
||||
},
|
||||
"required": ["street", "city"],
|
||||
},
|
||||
},
|
||||
"required": ["address"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(address={"street": "123 Main", "city": "Springfield"})
|
||||
assert obj.address.street == "123 Main"
|
||||
assert obj.address.city == "Springfield"
|
||||
|
||||
def test_object_without_properties_returns_dict(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {"type": "object"},
|
||||
},
|
||||
"required": ["metadata"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(metadata={"key": "value"})
|
||||
assert obj.metadata == {"key": "value"}
|
||||
|
||||
|
||||
class TestTypedArrays:
|
||||
def test_array_of_strings(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tags": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"required": ["tags"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(tags=["a", "b", "c"])
|
||||
assert obj.tags == ["a", "b", "c"]
|
||||
|
||||
def test_array_of_objects(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "integer"}},
|
||||
"required": ["id"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["items"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(items=[{"id": 1}, {"id": 2}])
|
||||
assert len(obj.items) == 2
|
||||
assert obj.items[0].id == 1
|
||||
|
||||
def test_untyped_array(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"data": {"type": "array"}},
|
||||
"required": ["data"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(data=[1, "two", 3.0])
|
||||
assert obj.data == [1, "two", 3.0]
|
||||
|
||||
|
||||
class TestUnionTypes:
|
||||
def test_anyof_string_or_integer(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}],
|
||||
},
|
||||
},
|
||||
"required": ["value"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
assert Model(value="hello").value == "hello"
|
||||
assert Model(value=42).value == 42
|
||||
|
||||
def test_oneof(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"oneOf": [{"type": "string"}, {"type": "number"}],
|
||||
},
|
||||
},
|
||||
"required": ["value"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
assert Model(value="hello").value == "hello"
|
||||
assert Model(value=3.14).value == pytest.approx(3.14)
|
||||
|
||||
|
||||
class TestAllOfMerging:
|
||||
def test_allof_merges_properties(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"allOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name"],
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"age": {"type": "integer"}},
|
||||
"required": ["age"],
|
||||
},
|
||||
],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(name="Alice", age=30)
|
||||
assert obj.name == "Alice"
|
||||
assert obj.age == 30
|
||||
|
||||
def test_single_allof(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"item": {
|
||||
"allOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "integer"}},
|
||||
"required": ["id"],
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["item"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(item={"id": 1})
|
||||
assert obj.item.id == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# $ref resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRefResolution:
|
||||
def test_ref_in_property(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"item": {"$ref": "#/$defs/Item"},
|
||||
},
|
||||
"required": ["item"],
|
||||
"$defs": {
|
||||
"Item": {
|
||||
"type": "object",
|
||||
"title": "Item",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name"],
|
||||
},
|
||||
},
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(item={"name": "Widget"})
|
||||
assert obj.item.name == "Widget"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_name parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestModelName:
|
||||
def test_model_name_override(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"title": "OriginalName",
|
||||
"properties": {"x": {"type": "integer"}},
|
||||
"required": ["x"],
|
||||
}
|
||||
Model = create_model_from_schema(schema, model_name="CustomSchema")
|
||||
assert Model.__name__ == "CustomSchema"
|
||||
|
||||
def test_model_name_fallback_to_title(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"title": "FromTitle",
|
||||
"properties": {"x": {"type": "integer"}},
|
||||
"required": ["x"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
assert Model.__name__ == "FromTitle"
|
||||
|
||||
def test_model_name_fallback_to_dynamic(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer"}},
|
||||
"required": ["x"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
assert Model.__name__ == "DynamicModel"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enrich_descriptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnrichDescriptions:
|
||||
def test_enriched_description_includes_constraints(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"score": {
|
||||
"type": "integer",
|
||||
"description": "The score value",
|
||||
"minimum": 0,
|
||||
"maximum": 100,
|
||||
},
|
||||
},
|
||||
"required": ["score"],
|
||||
}
|
||||
Model = create_model_from_schema(schema, enrich_descriptions=True)
|
||||
field_info = Model.model_fields["score"]
|
||||
assert "Minimum: 0" in field_info.description
|
||||
assert "Maximum: 100" in field_info.description
|
||||
assert "The score value" in field_info.description
|
||||
|
||||
def test_default_does_not_enrich(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"score": {
|
||||
"type": "integer",
|
||||
"description": "The score value",
|
||||
"minimum": 0,
|
||||
},
|
||||
},
|
||||
"required": ["score"],
|
||||
}
|
||||
Model = create_model_from_schema(schema, enrich_descriptions=False)
|
||||
field_info = Model.model_fields["score"]
|
||||
assert field_info.description == "The score value"
|
||||
|
||||
def test_enriched_description_propagates_to_nested(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level": {
|
||||
"type": "integer",
|
||||
"description": "Level",
|
||||
"minimum": 1,
|
||||
"maximum": 10,
|
||||
},
|
||||
},
|
||||
"required": ["level"],
|
||||
},
|
||||
},
|
||||
"required": ["config"],
|
||||
}
|
||||
Model = create_model_from_schema(schema, enrich_descriptions=True)
|
||||
nested_model = Model.model_fields["config"].annotation
|
||||
nested_field = nested_model.model_fields["level"]
|
||||
assert "Minimum: 1" in nested_field.description
|
||||
assert "Maximum: 10" in nested_field.description
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_empty_properties(self) -> None:
|
||||
schema = {"type": "object", "properties": {}, "required": []}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model()
|
||||
assert obj is not None
|
||||
|
||||
def test_no_properties_key(self) -> None:
|
||||
schema = {"type": "object"}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model()
|
||||
assert obj is not None
|
||||
|
||||
def test_unknown_type_raises(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"weird": {"type": "hyperspace"},
|
||||
},
|
||||
"required": ["weird"],
|
||||
}
|
||||
with pytest.raises(ValueError, match="Unsupported JSON schema type"):
|
||||
create_model_from_schema(schema)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_rich_field_description
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildRichFieldDescription:
|
||||
def test_description_only(self) -> None:
|
||||
assert build_rich_field_description({"description": "A name"}) == "A name"
|
||||
|
||||
def test_empty_schema(self) -> None:
|
||||
assert build_rich_field_description({}) == ""
|
||||
|
||||
def test_format(self) -> None:
|
||||
desc = build_rich_field_description({"format": "date-time"})
|
||||
assert "Format: date-time" in desc
|
||||
|
||||
def test_enum(self) -> None:
|
||||
desc = build_rich_field_description({"enum": ["a", "b"]})
|
||||
assert "Allowed values:" in desc
|
||||
assert "'a'" in desc
|
||||
assert "'b'" in desc
|
||||
|
||||
def test_pattern(self) -> None:
|
||||
desc = build_rich_field_description({"pattern": "^[a-z]+$"})
|
||||
assert "Pattern: ^[a-z]+$" in desc
|
||||
|
||||
def test_min_max(self) -> None:
|
||||
desc = build_rich_field_description({"minimum": 0, "maximum": 100})
|
||||
assert "Minimum: 0" in desc
|
||||
assert "Maximum: 100" in desc
|
||||
|
||||
def test_min_max_length(self) -> None:
|
||||
desc = build_rich_field_description({"minLength": 1, "maxLength": 255})
|
||||
assert "Min length: 1" in desc
|
||||
assert "Max length: 255" in desc
|
||||
|
||||
def test_examples(self) -> None:
|
||||
desc = build_rich_field_description({"examples": ["foo", "bar", "baz", "extra"]})
|
||||
assert "Examples:" in desc
|
||||
assert "'foo'" in desc
|
||||
assert "'baz'" in desc
|
||||
# Only first 3 shown
|
||||
assert "'extra'" not in desc
|
||||
|
||||
def test_combined_constraints(self) -> None:
|
||||
desc = build_rich_field_description({
|
||||
"description": "A score",
|
||||
"minimum": 0,
|
||||
"maximum": 10,
|
||||
"format": "int32",
|
||||
})
|
||||
assert desc.startswith("A score")
|
||||
assert "Minimum: 0" in desc
|
||||
assert "Maximum: 10" in desc
|
||||
assert "Format: int32" in desc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema transformation functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveRefs:
|
||||
def test_basic_ref_resolution(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"item": {"$ref": "#/$defs/Item"}},
|
||||
"$defs": {
|
||||
"Item": {"type": "object", "properties": {"id": {"type": "integer"}}},
|
||||
},
|
||||
}
|
||||
resolved = resolve_refs(schema)
|
||||
assert "$ref" not in resolved["properties"]["item"]
|
||||
assert resolved["properties"]["item"]["type"] == "object"
|
||||
|
||||
def test_nested_ref_resolution(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"wrapper": {"$ref": "#/$defs/Wrapper"}},
|
||||
"$defs": {
|
||||
"Wrapper": {
|
||||
"type": "object",
|
||||
"properties": {"inner": {"$ref": "#/$defs/Inner"}},
|
||||
},
|
||||
"Inner": {"type": "string"},
|
||||
},
|
||||
}
|
||||
resolved = resolve_refs(schema)
|
||||
wrapper = resolved["properties"]["wrapper"]
|
||||
assert wrapper["properties"]["inner"]["type"] == "string"
|
||||
|
||||
def test_missing_ref_raises(self) -> None:
|
||||
schema = {
|
||||
"properties": {"x": {"$ref": "#/$defs/Missing"}},
|
||||
"$defs": {},
|
||||
}
|
||||
with pytest.raises(KeyError, match="Missing"):
|
||||
resolve_refs(schema)
|
||||
|
||||
def test_no_refs_unchanged(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
resolved = resolve_refs(schema)
|
||||
assert resolved == schema
|
||||
|
||||
|
||||
class TestForceAdditionalPropertiesFalse:
|
||||
def test_adds_to_object(self) -> None:
|
||||
schema = {"type": "object", "properties": {"x": {"type": "integer"}}}
|
||||
result = force_additional_properties_false(deepcopy(schema))
|
||||
assert result["additionalProperties"] is False
|
||||
|
||||
def test_adds_empty_properties_and_required(self) -> None:
|
||||
schema = {"type": "object"}
|
||||
result = force_additional_properties_false(deepcopy(schema))
|
||||
assert result["properties"] == {}
|
||||
assert result["required"] == []
|
||||
|
||||
def test_recursive_nested_objects(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"child": {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "integer"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
result = force_additional_properties_false(deepcopy(schema))
|
||||
assert result["additionalProperties"] is False
|
||||
assert result["properties"]["child"]["additionalProperties"] is False
|
||||
|
||||
def test_does_not_affect_non_objects(self) -> None:
|
||||
schema = {"type": "string"}
|
||||
result = force_additional_properties_false(deepcopy(schema))
|
||||
assert "additionalProperties" not in result
|
||||
|
||||
|
||||
class TestStripUnsupportedFormats:
|
||||
def test_removes_email_format(self) -> None:
|
||||
schema = {"type": "string", "format": "email"}
|
||||
result = strip_unsupported_formats(deepcopy(schema))
|
||||
assert "format" not in result
|
||||
|
||||
def test_keeps_date_time(self) -> None:
|
||||
schema = {"type": "string", "format": "date-time"}
|
||||
result = strip_unsupported_formats(deepcopy(schema))
|
||||
assert result["format"] == "date-time"
|
||||
|
||||
def test_keeps_date(self) -> None:
|
||||
schema = {"type": "string", "format": "date"}
|
||||
result = strip_unsupported_formats(deepcopy(schema))
|
||||
assert result["format"] == "date"
|
||||
|
||||
def test_removes_uri_format(self) -> None:
|
||||
schema = {"type": "string", "format": "uri"}
|
||||
result = strip_unsupported_formats(deepcopy(schema))
|
||||
assert "format" not in result
|
||||
|
||||
def test_recursive(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email": {"type": "string", "format": "email"},
|
||||
"created": {"type": "string", "format": "date-time"},
|
||||
},
|
||||
}
|
||||
result = strip_unsupported_formats(deepcopy(schema))
|
||||
assert "format" not in result["properties"]["email"]
|
||||
assert result["properties"]["created"]["format"] == "date-time"
|
||||
|
||||
|
||||
class TestEnsureTypeInSchemas:
|
||||
def test_empty_schema_in_anyof_gets_type(self) -> None:
|
||||
schema = {"anyOf": [{}, {"type": "string"}]}
|
||||
result = ensure_type_in_schemas(deepcopy(schema))
|
||||
assert result["anyOf"][0] == {"type": "object"}
|
||||
|
||||
def test_empty_schema_in_oneof_gets_type(self) -> None:
|
||||
schema = {"oneOf": [{}, {"type": "integer"}]}
|
||||
result = ensure_type_in_schemas(deepcopy(schema))
|
||||
assert result["oneOf"][0] == {"type": "object"}
|
||||
|
||||
def test_non_empty_unchanged(self) -> None:
|
||||
schema = {"anyOf": [{"type": "string"}, {"type": "integer"}]}
|
||||
result = ensure_type_in_schemas(deepcopy(schema))
|
||||
assert result == schema
|
||||
|
||||
|
||||
class TestConvertOneofToAnyof:
|
||||
def test_converts_top_level(self) -> None:
|
||||
schema = {"oneOf": [{"type": "string"}, {"type": "integer"}]}
|
||||
result = convert_oneof_to_anyof(deepcopy(schema))
|
||||
assert "oneOf" not in result
|
||||
assert "anyOf" in result
|
||||
assert len(result["anyOf"]) == 2
|
||||
|
||||
def test_converts_nested(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {"oneOf": [{"type": "string"}, {"type": "number"}]},
|
||||
},
|
||||
}
|
||||
result = convert_oneof_to_anyof(deepcopy(schema))
|
||||
assert "anyOf" in result["properties"]["value"]
|
||||
assert "oneOf" not in result["properties"]["value"]
|
||||
|
||||
|
||||
class TestEnsureAllPropertiesRequired:
|
||||
def test_makes_all_required(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "string"}, "b": {"type": "integer"}},
|
||||
"required": ["a"],
|
||||
}
|
||||
result = ensure_all_properties_required(deepcopy(schema))
|
||||
assert set(result["required"]) == {"a", "b"}
|
||||
|
||||
def test_recursive(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"child": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer"}, "y": {"type": "integer"}},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
result = ensure_all_properties_required(deepcopy(schema))
|
||||
assert set(result["properties"]["child"]["required"]) == {"x", "y"}
|
||||
|
||||
|
||||
class TestStripNullFromTypes:
|
||||
def test_strips_null_from_anyof(self) -> None:
|
||||
schema = {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
}
|
||||
result = strip_null_from_types(deepcopy(schema))
|
||||
assert "anyOf" not in result
|
||||
assert result["type"] == "string"
|
||||
|
||||
def test_strips_null_from_type_array(self) -> None:
|
||||
schema = {"type": ["string", "null"]}
|
||||
result = strip_null_from_types(deepcopy(schema))
|
||||
assert result["type"] == "string"
|
||||
|
||||
def test_multiple_non_null_in_anyof(self) -> None:
|
||||
schema = {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}, {"type": "null"}],
|
||||
}
|
||||
result = strip_null_from_types(deepcopy(schema))
|
||||
assert len(result["anyOf"]) == 2
|
||||
|
||||
def test_no_null_unchanged(self) -> None:
|
||||
schema = {"type": "string"}
|
||||
result = strip_null_from_types(deepcopy(schema))
|
||||
assert result == schema
|
||||
|
||||
|
||||
class TestEndToEndMCPSchema:
|
||||
"""Realistic MCP tool schema exercising multiple features simultaneously."""
|
||||
|
||||
MCP_SCHEMA: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
"minLength": 1,
|
||||
"maxLength": 500,
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum results",
|
||||
"minimum": 1,
|
||||
"maximum": 100,
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["json", "csv", "xml"],
|
||||
"description": "Output format",
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"date_from": {"type": "string", "format": "date"},
|
||||
"date_to": {"type": "string", "format": "date"},
|
||||
"categories": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["date_from"],
|
||||
},
|
||||
"sort_order": {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
},
|
||||
},
|
||||
"required": ["query", "format", "filters"],
|
||||
}
|
||||
|
||||
def test_model_creation(self) -> None:
|
||||
Model = create_model_from_schema(self.MCP_SCHEMA)
|
||||
assert Model is not None
|
||||
assert issubclass(Model, BaseModel)
|
||||
|
||||
def test_valid_input_accepted(self) -> None:
|
||||
Model = create_model_from_schema(self.MCP_SCHEMA)
|
||||
obj = Model(
|
||||
query="test search",
|
||||
format="json",
|
||||
filters={"date_from": "2025-01-01"},
|
||||
)
|
||||
assert obj.query == "test search"
|
||||
assert obj.format == "json"
|
||||
|
||||
def test_invalid_enum_rejected(self) -> None:
|
||||
Model = create_model_from_schema(self.MCP_SCHEMA)
|
||||
with pytest.raises(Exception):
|
||||
Model(
|
||||
query="test",
|
||||
format="yaml",
|
||||
filters={"date_from": "2025-01-01"},
|
||||
)
|
||||
|
||||
def test_model_name_for_mcp_tool(self) -> None:
|
||||
Model = create_model_from_schema(
|
||||
self.MCP_SCHEMA, model_name="search_toolSchema"
|
||||
)
|
||||
assert Model.__name__ == "search_toolSchema"
|
||||
|
||||
def test_enriched_descriptions_for_mcp(self) -> None:
|
||||
Model = create_model_from_schema(
|
||||
self.MCP_SCHEMA, enrich_descriptions=True
|
||||
)
|
||||
query_field = Model.model_fields["query"]
|
||||
assert "Min length: 1" in query_field.description
|
||||
assert "Max length: 500" in query_field.description
|
||||
|
||||
max_results_field = Model.model_fields["max_results"]
|
||||
assert "Minimum: 1" in max_results_field.description
|
||||
assert "Maximum: 100" in max_results_field.description
|
||||
|
||||
format_field = Model.model_fields["format"]
|
||||
assert "Allowed values:" in format_field.description
|
||||
|
||||
def test_optional_fields_accept_none(self) -> None:
|
||||
Model = create_model_from_schema(self.MCP_SCHEMA)
|
||||
obj = Model(
|
||||
query="test",
|
||||
format="csv",
|
||||
filters={"date_from": "2025-01-01"},
|
||||
max_results=None,
|
||||
sort_order=None,
|
||||
)
|
||||
assert obj.max_results is None
|
||||
assert obj.sort_order is None
|
||||
|
||||
def test_nested_filters_validated(self) -> None:
|
||||
Model = create_model_from_schema(self.MCP_SCHEMA)
|
||||
obj = Model(
|
||||
query="test",
|
||||
format="xml",
|
||||
filters={
|
||||
"date_from": "2025-01-01",
|
||||
"date_to": "2025-12-31",
|
||||
"categories": ["news", "tech"],
|
||||
},
|
||||
)
|
||||
assert obj.filters.date_from == datetime.date(2025, 1, 1)
|
||||
assert obj.filters.categories == ["news", "tech"]
|
||||
@@ -1,3 +1,3 @@
|
||||
"""CrewAI development tools."""
|
||||
|
||||
__version__ = "1.10.0"
|
||||
__version__ = "1.9.3"
|
||||
|
||||
@@ -943,8 +943,6 @@ def tag(dry_run: bool, no_edit: bool) -> None:
|
||||
)
|
||||
|
||||
if docs_files_staged:
|
||||
docs_branch = f"docs/changelog-v{version}"
|
||||
run_command(["git", "checkout", "-b", docs_branch])
|
||||
for f in docs_files_staged:
|
||||
run_command(["git", "add", f])
|
||||
run_command(
|
||||
@@ -956,69 +954,8 @@ def tag(dry_run: bool, no_edit: bool) -> None:
|
||||
]
|
||||
)
|
||||
console.print("[green]✓[/green] Committed docs updates")
|
||||
|
||||
run_command(["git", "push", "-u", "origin", docs_branch])
|
||||
console.print(f"[green]✓[/green] Pushed branch {docs_branch}")
|
||||
|
||||
run_command(
|
||||
[
|
||||
"gh",
|
||||
"pr",
|
||||
"create",
|
||||
"--base",
|
||||
"main",
|
||||
"--title",
|
||||
f"docs: update changelog and version for v{version}",
|
||||
"--body",
|
||||
"",
|
||||
]
|
||||
)
|
||||
console.print("[green]✓[/green] Created docs PR")
|
||||
|
||||
run_command(
|
||||
[
|
||||
"gh",
|
||||
"pr",
|
||||
"merge",
|
||||
docs_branch,
|
||||
"--squash",
|
||||
"--auto",
|
||||
"--delete-branch",
|
||||
]
|
||||
)
|
||||
console.print("[green]✓[/green] Enabled auto-merge on docs PR")
|
||||
|
||||
import time
|
||||
|
||||
console.print("[cyan]Waiting for PR checks to pass and merge...[/cyan]")
|
||||
while True:
|
||||
time.sleep(10)
|
||||
try:
|
||||
state = run_command(
|
||||
[
|
||||
"gh",
|
||||
"pr",
|
||||
"view",
|
||||
docs_branch,
|
||||
"--json",
|
||||
"state",
|
||||
"--jq",
|
||||
".state",
|
||||
]
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
state = ""
|
||||
|
||||
if state == "MERGED":
|
||||
break
|
||||
|
||||
console.print("[dim]Still waiting for PR to merge...[/dim]")
|
||||
|
||||
console.print("[green]✓[/green] Docs PR merged")
|
||||
|
||||
run_command(["git", "checkout", "main"])
|
||||
run_command(["git", "pull"])
|
||||
console.print("[green]✓[/green] main branch updated with docs changes")
|
||||
run_command(["git", "push"])
|
||||
console.print("[green]✓[/green] Pushed docs updates")
|
||||
else:
|
||||
for lang in changelog_langs:
|
||||
cl_path = cwd / "docs" / lang / "changelog.mdx"
|
||||
@@ -1034,9 +971,6 @@ def tag(dry_run: bool, no_edit: bool) -> None:
|
||||
console.print(
|
||||
"[dim][DRY RUN][/dim] Skipping docs version (pre-release)"
|
||||
)
|
||||
console.print(
|
||||
f"[dim][DRY RUN][/dim] Would create branch docs/changelog-v{version}, PR, and merge"
|
||||
)
|
||||
|
||||
if not dry_run:
|
||||
with console.status(f"[cyan]Creating tag {tag_name}..."):
|
||||
|
||||
Reference in New Issue
Block a user