mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-06 19:18:16 +00:00
Compare commits
4 Commits
lorenze/im
...
docs/train
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e08d210bb8 | ||
|
|
34100d290b | ||
|
|
46f8fa59c1 | ||
|
|
4867dced0e |
@@ -320,7 +320,6 @@
|
||||
"en/enterprise/guides/update-crew",
|
||||
"en/enterprise/guides/enable-crew-studio",
|
||||
"en/enterprise/guides/azure-openai-setup",
|
||||
"en/enterprise/guides/automation-triggers",
|
||||
"en/enterprise/guides/hubspot-trigger",
|
||||
"en/enterprise/guides/react-component-export",
|
||||
"en/enterprise/guides/salesforce-trigger",
|
||||
@@ -659,7 +658,6 @@
|
||||
"pt-BR/enterprise/guides/update-crew",
|
||||
"pt-BR/enterprise/guides/enable-crew-studio",
|
||||
"pt-BR/enterprise/guides/azure-openai-setup",
|
||||
"pt-BR/enterprise/guides/automation-triggers",
|
||||
"pt-BR/enterprise/guides/hubspot-trigger",
|
||||
"pt-BR/enterprise/guides/react-component-export",
|
||||
"pt-BR/enterprise/guides/salesforce-trigger",
|
||||
@@ -1009,7 +1007,6 @@
|
||||
"ko/enterprise/guides/update-crew",
|
||||
"ko/enterprise/guides/enable-crew-studio",
|
||||
"ko/enterprise/guides/azure-openai-setup",
|
||||
"ko/enterprise/guides/automation-triggers",
|
||||
"ko/enterprise/guides/hubspot-trigger",
|
||||
"ko/enterprise/guides/react-component-export",
|
||||
"ko/enterprise/guides/salesforce-trigger",
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
---
|
||||
title: "Automation Triggers"
|
||||
description: "Automatically execute your CrewAI workflows when specific events occur in connected integrations"
|
||||
icon: "bolt"
|
||||
---
|
||||
|
||||
Automation triggers enable you to automatically run your CrewAI deployments when specific events occur in your connected integrations, creating powerful event-driven workflows that respond to real-time changes in your business systems.
|
||||
|
||||
## Overview
|
||||
|
||||
With automation triggers, you can:
|
||||
|
||||
- **Respond to real-time events** - Automatically execute workflows when specific conditions are met
|
||||
- **Integrate with external systems** - Connect with platforms like Gmail, Outlook, OneDrive, JIRA, Slack, Stripe and more
|
||||
- **Scale your automation** - Handle high-volume events without manual intervention
|
||||
- **Maintain context** - Access trigger data within your crews and flows
|
||||
|
||||
## Managing Automation Triggers
|
||||
|
||||
### Viewing Available Triggers
|
||||
|
||||
To access and manage your automation triggers:
|
||||
|
||||
1. Navigate to your deployment in the CrewAI dashboard
|
||||
2. Click on the **Triggers** tab to view all available trigger integrations
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-available-triggers.png" alt="List of available automation triggers" />
|
||||
</Frame>
|
||||
|
||||
This view shows all the trigger integrations available for your deployment, along with their current connection status.
|
||||
|
||||
### Enabling and Disabling Triggers
|
||||
|
||||
Each trigger can be easily enabled or disabled using the toggle switch:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/trigger-selected.png" alt="Enable or disable triggers with toggle" />
|
||||
</Frame>
|
||||
|
||||
- **Enabled (blue toggle)**: The trigger is active and will automatically execute your deployment when the specified events occur
|
||||
- **Disabled (gray toggle)**: The trigger is inactive and will not respond to events
|
||||
|
||||
Simply click the toggle to change the trigger state. Changes take effect immediately.
|
||||
|
||||
### Monitoring Trigger Executions
|
||||
|
||||
Track the performance and history of your triggered executions:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-executions.png" alt="List of executions triggered by automation" />
|
||||
</Frame>
|
||||
|
||||
## Building Automation
|
||||
|
||||
Before building your automation, it's helpful to understand the structure of trigger payloads that your crews and flows will receive.
|
||||
|
||||
### Payload Samples Repository
|
||||
|
||||
We maintain a comprehensive repository with sample payloads from various trigger sources to help you build and test your automations:
|
||||
|
||||
**🔗 [CrewAI Enterprise Trigger Payload Samples](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
|
||||
|
||||
This repository contains:
|
||||
|
||||
- **Real payload examples** from different trigger sources (Gmail, Google Drive, etc.)
|
||||
- **Payload structure documentation** showing the format and available fields
|
||||
|
||||
### Triggers with Crew
|
||||
|
||||
Your existing crew definitions work seamlessly with triggers, you just need to have a task to parse the received payload:
|
||||
|
||||
```python
|
||||
@CrewBase
|
||||
class MyAutomatedCrew:
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['researcher'],
|
||||
)
|
||||
|
||||
@task
|
||||
def parse_trigger_payload(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['parse_trigger_payload'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
|
||||
@task
|
||||
def analyze_trigger_content(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analyze_trigger_data'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
```
|
||||
|
||||
The crew will automatically receive and can access the trigger payload through the standard CrewAI context mechanisms.
|
||||
|
||||
### Integration with Flows
|
||||
|
||||
For flows, you have more control over how trigger data is handled:
|
||||
|
||||
#### Accessing Trigger Payload
|
||||
|
||||
All `@start()` methods in your flows will accept an additional parameter called `crewai_trigger_payload`:
|
||||
|
||||
```python
|
||||
from crewai.flow import Flow, start, listen
|
||||
|
||||
class MyAutomatedFlow(Flow):
|
||||
@start()
|
||||
def handle_trigger(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
This start method can receive trigger data
|
||||
"""
|
||||
if crewai_trigger_payload:
|
||||
# Process the trigger data
|
||||
trigger_id = crewai_trigger_payload.get('id')
|
||||
event_data = crewai_trigger_payload.get('payload', {})
|
||||
|
||||
# Store in flow state for use by other methods
|
||||
self.state.trigger_id = trigger_id
|
||||
self.state.trigger_type = event_data
|
||||
|
||||
return event_data
|
||||
|
||||
# Handle manual execution
|
||||
return None
|
||||
|
||||
@listen(handle_trigger)
|
||||
def process_data(self, trigger_data):
|
||||
"""
|
||||
Process the data from the trigger
|
||||
"""
|
||||
# ... process the trigger
|
||||
```
|
||||
|
||||
#### Triggering Crews from Flows
|
||||
|
||||
When kicking off a crew within a flow that was triggered, pass the trigger payload as it:
|
||||
|
||||
```python
|
||||
@start()
|
||||
def delegate_to_crew(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
Delegate processing to a specialized crew
|
||||
"""
|
||||
crew = MySpecializedCrew()
|
||||
|
||||
# Pass the trigger payload to the crew
|
||||
result = crew.crew().kickoff(
|
||||
inputs={
|
||||
'a_custom_parameter': "custom_value",
|
||||
'crewai_trigger_payload': crewai_trigger_payload
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Trigger not firing:**
|
||||
- Verify the trigger is enabled
|
||||
- Check integration connection status
|
||||
|
||||
**Execution failures:**
|
||||
- Check the execution logs for error details
|
||||
- If you are developing, make sure the inputs include the `crewai_trigger_payload` parameter with the correct payload
|
||||
|
||||
Automation triggers transform your CrewAI deployments into responsive, event-driven systems that can seamlessly integrate with your existing business processes and tools.
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 142 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 330 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 133 KiB |
@@ -1,171 +0,0 @@
|
||||
---
|
||||
title: "자동화 트리거"
|
||||
description: "연결된 통합에서 특정 이벤트가 발생할 때 CrewAI 워크플로우를 자동으로 실행합니다"
|
||||
icon: "bolt"
|
||||
---
|
||||
|
||||
자동화 트리거를 사용하면 연결된 통합에서 특정 이벤트가 발생할 때 CrewAI 배포를 자동으로 실행할 수 있어, 비즈니스 시스템의 실시간 변화에 반응하는 강력한 이벤트 기반 워크플로우를 만들 수 있습니다.
|
||||
|
||||
## 개요
|
||||
|
||||
자동화 트리거를 사용하면 다음을 수행할 수 있습니다:
|
||||
|
||||
- **실시간 이벤트에 응답** - 특정 조건이 충족될 때 워크플로우를 자동으로 실행
|
||||
- **외부 시스템과 통합** - Gmail, Outlook, OneDrive, JIRA, Slack, Stripe 등의 플랫폼과 연결
|
||||
- **자동화 확장** - 수동 개입 없이 대용량 이벤트 처리
|
||||
- **컨텍스트 유지** - crew와 flow 내에서 트리거 데이터에 액세스
|
||||
|
||||
## 자동화 트리거 관리
|
||||
|
||||
### 사용 가능한 트리거 보기
|
||||
|
||||
자동화 트리거에 액세스하고 관리하려면:
|
||||
|
||||
1. CrewAI 대시보드에서 배포로 이동
|
||||
2. **트리거** 탭을 클릭하여 사용 가능한 모든 트리거 통합 보기
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-available-triggers.png" alt="사용 가능한 자동화 트리거 목록" />
|
||||
</Frame>
|
||||
|
||||
이 보기는 배포에 사용 가능한 모든 트리거 통합과 현재 연결 상태를 보여줍니다.
|
||||
|
||||
### 트리거 활성화 및 비활성화
|
||||
|
||||
각 트리거는 토글 스위치를 사용하여 쉽게 활성화하거나 비활성화할 수 있습니다:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/trigger-selected.png" alt="토글로 트리거 활성화 또는 비활성화" />
|
||||
</Frame>
|
||||
|
||||
- **활성화됨 (파란색 토글)**: 트리거가 활성 상태이며 지정된 이벤트가 발생할 때 배포를 자동으로 실행합니다
|
||||
- **비활성화됨 (회색 토글)**: 트리거가 비활성 상태이며 이벤트에 응답하지 않습니다
|
||||
|
||||
토글을 클릭하기만 하면 트리거 상태를 변경할 수 있습니다. 변경 사항은 즉시 적용됩니다.
|
||||
|
||||
### 트리거 실행 모니터링
|
||||
|
||||
트리거된 실행의 성능과 기록을 추적합니다:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-executions.png" alt="자동화에 의해 트리거된 실행 목록" />
|
||||
</Frame>
|
||||
|
||||
## 자동화 구축
|
||||
|
||||
자동화를 구축하기 전에 crew와 flow가 받을 트리거 페이로드의 구조를 이해하는 것이 도움이 됩니다.
|
||||
|
||||
### 페이로드 샘플 저장소
|
||||
|
||||
자동화를 구축하고 테스트하는 데 도움이 되도록 다양한 트리거 소스의 샘플 페이로드가 포함된 포괄적인 저장소를 유지 관리하고 있습니다:
|
||||
|
||||
**🔗 [CrewAI Enterprise 트리거 페이로드 샘플](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
|
||||
|
||||
이 저장소에는 다음이 포함되어 있습니다:
|
||||
|
||||
- **실제 페이로드 예제** - 다양한 트리거 소스(Gmail, Google Drive 등)에서 가져온 예제
|
||||
- **페이로드 구조 문서** - 형식과 사용 가능한 필드를 보여주는 문서
|
||||
|
||||
### Crew와 트리거
|
||||
|
||||
기존 crew 정의는 트리거와 완벽하게 작동하며, 받은 페이로드를 분석하는 작업만 있으면 됩니다:
|
||||
|
||||
```python
|
||||
@CrewBase
|
||||
class MyAutomatedCrew:
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['researcher'],
|
||||
)
|
||||
|
||||
@task
|
||||
def parse_trigger_payload(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['parse_trigger_payload'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
|
||||
@task
|
||||
def analyze_trigger_content(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analyze_trigger_data'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
```
|
||||
|
||||
crew는 자동으로 트리거 페이로드를 받고 표준 CrewAI 컨텍스트 메커니즘을 통해 액세스할 수 있습니다.
|
||||
|
||||
### Flow와의 통합
|
||||
|
||||
flow의 경우 트리거 데이터 처리 방법을 더 세밀하게 제어할 수 있습니다:
|
||||
|
||||
#### 트리거 페이로드 액세스
|
||||
|
||||
flow의 모든 `@start()` 메서드는 `crewai_trigger_payload`라는 추가 매개변수를 허용합니다:
|
||||
|
||||
```python
|
||||
from crewai.flow import Flow, start, listen
|
||||
|
||||
class MyAutomatedFlow(Flow):
|
||||
@start()
|
||||
def handle_trigger(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
이 start 메서드는 트리거 데이터를 받을 수 있습니다
|
||||
"""
|
||||
if crewai_trigger_payload:
|
||||
# 트리거 데이터 처리
|
||||
trigger_id = crewai_trigger_payload.get('id')
|
||||
event_data = crewai_trigger_payload.get('payload', {})
|
||||
|
||||
# 다른 메서드에서 사용할 수 있도록 flow 상태에 저장
|
||||
self.state.trigger_id = trigger_id
|
||||
self.state.trigger_type = event_data
|
||||
|
||||
return event_data
|
||||
|
||||
# 수동 실행 처리
|
||||
return None
|
||||
|
||||
@listen(handle_trigger)
|
||||
def process_data(self, trigger_data):
|
||||
"""
|
||||
트리거 데이터 처리
|
||||
"""
|
||||
# ... 트리거 처리
|
||||
```
|
||||
|
||||
#### Flow에서 Crew 트리거하기
|
||||
|
||||
트리거된 flow 내에서 crew를 시작할 때 트리거 페이로드를 전달합니다:
|
||||
|
||||
```python
|
||||
@start()
|
||||
def delegate_to_crew(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
전문 crew에 처리 위임
|
||||
"""
|
||||
crew = MySpecializedCrew()
|
||||
|
||||
# crew에 트리거 페이로드 전달
|
||||
result = crew.crew().kickoff(
|
||||
inputs={
|
||||
'a_custom_parameter': "custom_value",
|
||||
'crewai_trigger_payload': crewai_trigger_payload
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
## 문제 해결
|
||||
|
||||
**트리거가 작동하지 않는 경우:**
|
||||
- 트리거가 활성화되어 있는지 확인
|
||||
- 통합 연결 상태 확인
|
||||
|
||||
**실행 실패:**
|
||||
- 오류 세부 정보는 실행 로그 확인
|
||||
- 개발 중인 경우 입력에 올바른 페이로드가 포함된 `crewai_trigger_payload` 매개변수가 포함되어 있는지 확인
|
||||
|
||||
자동화 트리거는 CrewAI 배포를 기존 비즈니스 프로세스 및 도구와 완벽하게 통합할 수 있는 반응형 이벤트 기반 시스템으로 변환합니다.
|
||||
@@ -1,171 +0,0 @@
|
||||
---
|
||||
title: "Triggers de Automação"
|
||||
description: "Execute automaticamente seus workflows CrewAI quando eventos específicos ocorrem em integrações conectadas"
|
||||
icon: "bolt"
|
||||
---
|
||||
|
||||
Os triggers de automação permitem executar automaticamente suas implantações CrewAI quando eventos específicos ocorrem em suas integrações conectadas, criando workflows poderosos orientados por eventos que respondem a mudanças em tempo real em seus sistemas de negócio.
|
||||
|
||||
## Visão Geral
|
||||
|
||||
Com triggers de automação, você pode:
|
||||
|
||||
- **Responder a eventos em tempo real** - Execute workflows automaticamente quando condições específicas forem atendidas
|
||||
- **Integrar com sistemas externos** - Conecte com plataformas como Gmail, Outlook, OneDrive, JIRA, Slack, Stripe e muito mais
|
||||
- **Escalar sua automação** - Lide com eventos de alto volume sem intervenção manual
|
||||
- **Manter contexto** - Acesse dados do trigger dentro de suas crews e flows
|
||||
|
||||
## Gerenciando Triggers de Automação
|
||||
|
||||
### Visualizando Triggers Disponíveis
|
||||
|
||||
Para acessar e gerenciar seus triggers de automação:
|
||||
|
||||
1. Navegue até sua implantação no painel do CrewAI
|
||||
2. Clique na aba **Triggers** para visualizar todas as integrações de trigger disponíveis
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-available-triggers.png" alt="Lista de triggers de automação disponíveis" />
|
||||
</Frame>
|
||||
|
||||
Esta visualização mostra todas as integrações de trigger disponíveis para sua implantação, junto com seus status de conexão atuais.
|
||||
|
||||
### Habilitando e Desabilitando Triggers
|
||||
|
||||
Cada trigger pode ser facilmente habilitado ou desabilitado usando o botão de alternância:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/trigger-selected.png" alt="Habilitar ou desabilitar triggers com alternância" />
|
||||
</Frame>
|
||||
|
||||
- **Habilitado (alternância azul)**: O trigger está ativo e executará automaticamente sua implantação quando os eventos especificados ocorrerem
|
||||
- **Desabilitado (alternância cinza)**: O trigger está inativo e não responderá a eventos
|
||||
|
||||
Simplesmente clique na alternância para mudar o estado do trigger. As alterações entram em vigor imediatamente.
|
||||
|
||||
### Monitorando Execuções de Trigger
|
||||
|
||||
Acompanhe o desempenho e histórico de suas execuções acionadas:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-executions.png" alt="Lista de execuções acionadas por automação" />
|
||||
</Frame>
|
||||
|
||||
## Construindo Automação
|
||||
|
||||
Antes de construir sua automação, é útil entender a estrutura dos payloads de trigger que suas crews e flows receberão.
|
||||
|
||||
### Repositório de Amostras de Payload
|
||||
|
||||
Mantemos um repositório abrangente com amostras de payload de várias fontes de trigger para ajudá-lo a construir e testar suas automações:
|
||||
|
||||
**🔗 [Amostras de Payload de Trigger CrewAI Enterprise](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
|
||||
|
||||
Este repositório contém:
|
||||
|
||||
- **Exemplos reais de payload** de diferentes fontes de trigger (Gmail, Google Drive, etc.)
|
||||
- **Documentação da estrutura de payload** mostrando o formato e campos disponíveis
|
||||
|
||||
### Triggers com Crew
|
||||
|
||||
Suas definições de crew existentes funcionam perfeitamente com triggers, você só precisa ter uma tarefa para analisar o payload recebido:
|
||||
|
||||
```python
|
||||
@CrewBase
|
||||
class MinhaCrewAutomatizada:
|
||||
@agent
|
||||
def pesquisador(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['pesquisador'],
|
||||
)
|
||||
|
||||
@task
|
||||
def analisar_payload_trigger(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analisar_payload_trigger'],
|
||||
agent=self.pesquisador(),
|
||||
)
|
||||
|
||||
@task
|
||||
def analisar_conteudo_trigger(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analisar_dados_trigger'],
|
||||
agent=self.pesquisador(),
|
||||
)
|
||||
```
|
||||
|
||||
A crew receberá automaticamente e pode acessar o payload do trigger através dos mecanismos de contexto padrão do CrewAI.
|
||||
|
||||
### Integração com Flows
|
||||
|
||||
Para flows, você tem mais controle sobre como os dados do trigger são tratados:
|
||||
|
||||
#### Acessando Payload do Trigger
|
||||
|
||||
Todos os métodos `@start()` em seus flows aceitarão um parâmetro adicional chamado `crewai_trigger_payload`:
|
||||
|
||||
```python
|
||||
from crewai.flow import Flow, start, listen
|
||||
|
||||
class MeuFlowAutomatizado(Flow):
|
||||
@start()
|
||||
def lidar_com_trigger(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
Este método start pode receber dados do trigger
|
||||
"""
|
||||
if crewai_trigger_payload:
|
||||
# Processa os dados do trigger
|
||||
trigger_id = crewai_trigger_payload.get('id')
|
||||
dados_evento = crewai_trigger_payload.get('payload', {})
|
||||
|
||||
# Armazena no estado do flow para uso por outros métodos
|
||||
self.state.trigger_id = trigger_id
|
||||
self.state.trigger_type = dados_evento
|
||||
|
||||
return dados_evento
|
||||
|
||||
# Lida com execução manual
|
||||
return None
|
||||
|
||||
@listen(lidar_com_trigger)
|
||||
def processar_dados(self, dados_trigger):
|
||||
"""
|
||||
Processa os dados do trigger
|
||||
"""
|
||||
# ... processa o trigger
|
||||
```
|
||||
|
||||
#### Acionando Crews a partir de Flows
|
||||
|
||||
Ao iniciar uma crew dentro de um flow que foi acionado, passe o payload do trigger conforme ele:
|
||||
|
||||
```python
|
||||
@start()
|
||||
def delegar_para_crew(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
Delega processamento para uma crew especializada
|
||||
"""
|
||||
crew = MinhaCrewEspecializada()
|
||||
|
||||
# Passa o payload do trigger para a crew
|
||||
resultado = crew.crew().kickoff(
|
||||
inputs={
|
||||
'parametro_personalizado': "valor_personalizado",
|
||||
'crewai_trigger_payload': crewai_trigger_payload
|
||||
},
|
||||
)
|
||||
|
||||
return resultado
|
||||
```
|
||||
|
||||
## Solução de Problemas
|
||||
|
||||
**Trigger não está sendo disparado:**
|
||||
- Verifique se o trigger está habilitado
|
||||
- Verifique o status de conexão da integração
|
||||
|
||||
**Falhas de execução:**
|
||||
- Verifique os logs de execução para detalhes do erro
|
||||
- Se você está desenvolvendo, certifique-se de que as entradas incluem o parâmetro `crewai_trigger_payload` com o payload correto
|
||||
|
||||
Os triggers de automação transformam suas implantações CrewAI em sistemas responsivos orientados por eventos que podem se integrar perfeitamente com seus processos de negócio e ferramentas existentes.
|
||||
@@ -1,18 +1,7 @@
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
@@ -173,7 +162,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
description="Function or string description of a guardrail to validate agent output"
|
||||
)
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
@@ -287,7 +276,7 @@ class Agent(BaseAgent):
|
||||
self._inject_date_to_task(task)
|
||||
|
||||
if self.tools_handler:
|
||||
self.tools_handler.last_used_tool = None
|
||||
self.tools_handler.last_used_tool = {} # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
|
||||
|
||||
task_prompt = task.prompt()
|
||||
|
||||
@@ -347,6 +336,7 @@ class Agent(BaseAgent):
|
||||
self.knowledge_config.model_dump() if self.knowledge_config else {}
|
||||
)
|
||||
|
||||
|
||||
if self.knowledge or (self.crew and self.crew.knowledge):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -8,13 +8,13 @@ from .cache.cache_handler import CacheHandler
|
||||
class ToolsHandler:
|
||||
"""Callback handler for tool usage."""
|
||||
|
||||
last_used_tool: Optional[ToolCalling] = None
|
||||
last_used_tool: ToolCalling = {} # type: ignore # BUG?: Incompatible types in assignment (expression has type "Dict[...]", variable has type "ToolCalling")
|
||||
cache: Optional[CacheHandler]
|
||||
|
||||
def __init__(self, cache: Optional[CacheHandler] = None):
|
||||
"""Initialize the callback handler."""
|
||||
self.cache = cache
|
||||
self.last_used_tool = None
|
||||
self.last_used_tool = {} # type: ignore # BUG?: same as above
|
||||
|
||||
def on_tool_use(
|
||||
self,
|
||||
|
||||
@@ -474,7 +474,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||
self._completed_methods: Set[str] = set() # Track completed methods for reload
|
||||
self._persistence: Optional[FlowPersistence] = persistence
|
||||
self._is_execution_resuming: bool = False
|
||||
|
||||
# Initialize state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
@@ -830,9 +829,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Clear completed methods and outputs for a fresh start
|
||||
self._completed_methods.clear()
|
||||
self._method_outputs.clear()
|
||||
else:
|
||||
# We're restoring from persistence, set the flag
|
||||
self._is_execution_resuming = True
|
||||
|
||||
if inputs:
|
||||
# Override the id in the state if it exists in inputs
|
||||
@@ -884,9 +880,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Clear the resumption flag after initial execution completes
|
||||
self._is_execution_resuming = False
|
||||
|
||||
final_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -923,13 +916,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
- Automatically injects crewai_trigger_payload if available in flow inputs
|
||||
"""
|
||||
if start_method_name in self._completed_methods:
|
||||
if self._is_execution_resuming:
|
||||
# During resumption, skip execution but continue listeners
|
||||
last_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
await self._execute_listeners(start_method_name, last_output)
|
||||
return
|
||||
# For cyclic flows, clear from completed to allow re-execution
|
||||
self._completed_methods.discard(start_method_name)
|
||||
last_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
await self._execute_listeners(start_method_name, last_output)
|
||||
return
|
||||
|
||||
method = self._methods[start_method_name]
|
||||
enhanced_method = self._inject_trigger_payload_for_start_method(method)
|
||||
@@ -1061,15 +1050,11 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
for router_name in routers_triggered:
|
||||
await self._execute_single_listener(router_name, result)
|
||||
# After executing router, the router's result is the path
|
||||
router_result = (
|
||||
self._method_outputs[-1] if self._method_outputs else None
|
||||
)
|
||||
router_result = self._method_outputs[-1]
|
||||
if router_result: # Only add non-None results
|
||||
router_results.append(router_result)
|
||||
current_trigger = (
|
||||
str(router_result)
|
||||
if router_result is not None
|
||||
else "" # Update for next iteration of router chain
|
||||
router_result # Update for next iteration of router chain
|
||||
)
|
||||
|
||||
# Now execute normal listeners for all router results and the original trigger
|
||||
@@ -1087,24 +1072,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
if current_trigger in router_results:
|
||||
# Find start methods triggered by this router result
|
||||
for method_name in self._start_methods:
|
||||
# Check if this start method is triggered by the current trigger
|
||||
if method_name in self._listeners:
|
||||
condition_type, trigger_methods = self._listeners[
|
||||
method_name
|
||||
]
|
||||
if current_trigger in trigger_methods:
|
||||
# Only execute if this is a cycle (method was already completed)
|
||||
if method_name in self._completed_methods:
|
||||
# For router-triggered start methods in cycles, temporarily clear resumption flag
|
||||
# to allow cyclic execution
|
||||
was_resuming = self._is_execution_resuming
|
||||
self._is_execution_resuming = False
|
||||
await self._execute_start_method(method_name)
|
||||
self._is_execution_resuming = was_resuming
|
||||
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: str, router_only: bool
|
||||
) -> List[str]:
|
||||
@@ -1142,9 +1109,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if router_only != is_router:
|
||||
continue
|
||||
|
||||
if not router_only and listener_name in self._start_methods:
|
||||
continue
|
||||
|
||||
if condition_type == "OR":
|
||||
# If the trigger_method matches any in methods, run this
|
||||
if trigger_method in methods:
|
||||
@@ -1194,13 +1158,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Catches and logs any exceptions during execution, preventing
|
||||
individual listener failures from breaking the entire flow.
|
||||
"""
|
||||
if listener_name in self._completed_methods:
|
||||
if self._is_execution_resuming:
|
||||
# During resumption, skip execution but continue listeners
|
||||
await self._execute_listeners(listener_name, None)
|
||||
return
|
||||
# For cyclic flows, clear from completed to allow re-execution
|
||||
self._completed_methods.discard(listener_name)
|
||||
# TODO: greyson fix
|
||||
# if listener_name in self._completed_methods:
|
||||
# await self._execute_listeners(listener_name, None)
|
||||
# return
|
||||
|
||||
try:
|
||||
method = self._methods[listener_name]
|
||||
|
||||
@@ -316,143 +316,6 @@ class LLM(BaseLLM):
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Check for provider prefixes and route to native implementations
|
||||
if "/" in model:
|
||||
provider, actual_model = model.split("/", 1)
|
||||
|
||||
# Route to OpenAI native implementation
|
||||
if provider.lower() == "openai":
|
||||
try:
|
||||
from crewai.llms.openai import OpenAILLM
|
||||
|
||||
# Create native OpenAI instance with all the same parameters
|
||||
native_llm = OpenAILLM(
|
||||
model=actual_model,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
stop=stop,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
base_url=base_url,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
callbacks=callbacks,
|
||||
reasoning_effort=reasoning_effort,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Replace this LLM instance with the native one
|
||||
self.__class__ = native_llm.__class__
|
||||
self.__dict__.update(native_llm.__dict__)
|
||||
return
|
||||
|
||||
except ImportError:
|
||||
# Fall back to LiteLLM if native implementation unavailable
|
||||
print(
|
||||
f"Native OpenAI implementation not available, using LiteLLM for {model}"
|
||||
)
|
||||
model = actual_model # Remove the prefix for LiteLLM
|
||||
|
||||
# Route to Claude native implementation
|
||||
elif provider.lower() == "anthropic":
|
||||
try:
|
||||
from crewai.llms.anthropic import ClaudeLLM
|
||||
|
||||
# Create native Claude instance with all the same parameters
|
||||
native_llm = ClaudeLLM(
|
||||
model=actual_model,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
stop=stop,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
base_url=base_url,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
callbacks=callbacks,
|
||||
reasoning_effort=reasoning_effort,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Replace this LLM instance with the native one
|
||||
self.__class__ = native_llm.__class__
|
||||
self.__dict__.update(native_llm.__dict__)
|
||||
return
|
||||
|
||||
except ImportError:
|
||||
# Fall back to LiteLLM if native implementation unavailable
|
||||
print(
|
||||
f"Native Claude implementation not available, using LiteLLM for {model}"
|
||||
)
|
||||
model = actual_model # Remove the prefix for LiteLLM
|
||||
|
||||
# Route to Gemini native implementation
|
||||
elif provider.lower() == "google":
|
||||
try:
|
||||
from crewai.llms.google import GeminiLLM
|
||||
|
||||
# Create native Gemini instance with all the same parameters
|
||||
native_llm = GeminiLLM(
|
||||
model=actual_model,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
stop=stop,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
base_url=base_url,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
callbacks=callbacks,
|
||||
reasoning_effort=reasoning_effort,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Replace this LLM instance with the native one
|
||||
self.__class__ = native_llm.__class__
|
||||
self.__dict__.update(native_llm.__dict__)
|
||||
return
|
||||
|
||||
except ImportError:
|
||||
# Fall back to LiteLLM if native implementation unavailable
|
||||
print(
|
||||
f"Native Gemini implementation not available, using LiteLLM for {model}"
|
||||
)
|
||||
model = actual_model # Remove the prefix for LiteLLM
|
||||
|
||||
# Continue with original LiteLLM initialization
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.temperature = temperature
|
||||
@@ -1276,11 +1139,7 @@ class LLM(BaseLLM):
|
||||
|
||||
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
|
||||
# Ollama doesn't supports last message to be 'assistant'
|
||||
if (
|
||||
"ollama" in self.model.lower()
|
||||
and messages
|
||||
and messages[-1]["role"] == "assistant"
|
||||
):
|
||||
if "ollama" in self.model.lower() and messages and messages[-1]["role"] == "assistant":
|
||||
return messages + [{"role": "user", "content": ""}]
|
||||
|
||||
# Handle Anthropic models
|
||||
|
||||
@@ -1,11 +1 @@
|
||||
"""CrewAI LLM implementations."""
|
||||
|
||||
from .base_llm import BaseLLM
|
||||
from .openai import OpenAILLM
|
||||
from .anthropic import ClaudeLLM
|
||||
from .google import GeminiLLM
|
||||
|
||||
# Import the main LLM class for backward compatibility
|
||||
|
||||
|
||||
__all__ = ["BaseLLM", "OpenAILLM", "ClaudeLLM", "GeminiLLM"]
|
||||
"""LLM implementations for crewAI."""
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Anthropic Claude LLM implementation for CrewAI."""
|
||||
|
||||
from .claude import ClaudeLLM
|
||||
|
||||
__all__ = ["ClaudeLLM"]
|
||||
@@ -1,569 +0,0 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, Type, Literal
|
||||
from anthropic import Anthropic
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
)
|
||||
from crewai.utilities.events.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ClaudeLLM(BaseLLM):
|
||||
"""Anthropic Claude LLM implementation with full LLM class compatibility."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "claude-3-5-sonnet-20241022",
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None, # Not supported by Claude but kept for compatibility
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[
|
||||
float
|
||||
] = None, # Not supported but kept for compatibility
|
||||
frequency_penalty: Optional[
|
||||
float
|
||||
] = None, # Not supported but kept for compatibility
|
||||
logit_bias: Optional[
|
||||
Dict[int, float]
|
||||
] = None, # Not supported but kept for compatibility
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
seed: Optional[int] = None, # Not supported but kept for compatibility
|
||||
logprobs: Optional[int] = None, # Not supported but kept for compatibility
|
||||
top_logprobs: Optional[int] = None, # Not supported but kept for compatibility
|
||||
base_url: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None, # Not used by Anthropic
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[
|
||||
Literal["none", "low", "medium", "high"]
|
||||
] = None, # Not used by Claude
|
||||
stream: bool = False,
|
||||
max_retries: int = 2,
|
||||
# Claude-specific parameters
|
||||
thinking_mode: bool = False, # Enable Claude's thinking mode
|
||||
top_k: Optional[int] = None, # Claude-specific sampling parameter
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Claude LLM with full compatibility.
|
||||
|
||||
Args:
|
||||
model: Claude model name (e.g., 'claude-3-5-sonnet-20241022')
|
||||
timeout: Request timeout in seconds
|
||||
temperature: Sampling temperature (0-1 for Claude)
|
||||
top_p: Nucleus sampling parameter
|
||||
n: Number of completions (not supported by Claude, kept for compatibility)
|
||||
stop: Stop sequences
|
||||
max_completion_tokens: Maximum tokens in completion
|
||||
max_tokens: Maximum tokens (legacy parameter)
|
||||
presence_penalty: Not supported by Claude, kept for compatibility
|
||||
frequency_penalty: Not supported by Claude, kept for compatibility
|
||||
logit_bias: Not supported by Claude, kept for compatibility
|
||||
response_format: Pydantic model for structured output
|
||||
seed: Not supported by Claude, kept for compatibility
|
||||
logprobs: Not supported by Claude, kept for compatibility
|
||||
top_logprobs: Not supported by Claude, kept for compatibility
|
||||
base_url: Custom API base URL
|
||||
api_base: Legacy API base parameter
|
||||
api_version: Not used by Anthropic
|
||||
api_key: Anthropic API key
|
||||
callbacks: List of callback functions
|
||||
reasoning_effort: Not used by Claude, kept for compatibility
|
||||
stream: Whether to stream responses
|
||||
max_retries: Number of retries for failed requests
|
||||
thinking_mode: Enable Claude's thinking mode (if supported)
|
||||
top_k: Claude-specific top-k sampling parameter
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(model=model, temperature=temperature)
|
||||
|
||||
# Store all parameters for compatibility
|
||||
self.timeout = timeout
|
||||
self.top_p = top_p
|
||||
self.n = n # Claude doesn't support n>1, but we store it for compatibility
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens or max_completion_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.api_base = api_base or base_url
|
||||
self.base_url = base_url or api_base
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
self.callbacks = callbacks
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.stream = stream
|
||||
self.additional_params = kwargs
|
||||
self.context_window_size = 0
|
||||
|
||||
# Claude-specific parameters
|
||||
self.thinking_mode = thinking_mode
|
||||
self.top_k = top_k
|
||||
|
||||
# Normalize stop parameter to match LLM class behavior
|
||||
if stop is None:
|
||||
self.stop: List[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = stop
|
||||
|
||||
# Initialize Anthropic client
|
||||
client_kwargs = {}
|
||||
if self.api_key:
|
||||
client_kwargs["api_key"] = self.api_key
|
||||
if self.base_url:
|
||||
client_kwargs["base_url"] = self.base_url
|
||||
if self.timeout:
|
||||
client_kwargs["timeout"] = self.timeout
|
||||
if max_retries:
|
||||
client_kwargs["max_retries"] = max_retries
|
||||
|
||||
# Add any additional kwargs that might be relevant to the client
|
||||
for key, value in kwargs.items():
|
||||
if key not in ["thinking_mode", "top_k"]: # Exclude our custom params
|
||||
client_kwargs[key] = value
|
||||
|
||||
self.client = Anthropic(**client_kwargs)
|
||||
self.model_config = self._get_model_config()
|
||||
|
||||
def _get_model_config(self) -> Dict[str, Any]:
|
||||
"""Get model-specific configuration for Claude models."""
|
||||
# Claude model configurations based on Anthropic's documentation
|
||||
model_configs = {
|
||||
# Claude 3.5 Sonnet
|
||||
"claude-3-5-sonnet-20241022": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"claude-3-5-sonnet-20240620": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Claude 3.5 Haiku
|
||||
"claude-3-5-haiku-20241022": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Claude 3 Opus
|
||||
"claude-3-opus-20240229": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Claude 3 Sonnet
|
||||
"claude-3-sonnet-20240229": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Claude 3 Haiku
|
||||
"claude-3-haiku-20240307": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Claude 2.1
|
||||
"claude-2.1": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": False,
|
||||
"supports_vision": False,
|
||||
},
|
||||
"claude-2": {
|
||||
"context_window": 100000,
|
||||
"supports_tools": False,
|
||||
"supports_vision": False,
|
||||
},
|
||||
# Claude Instant
|
||||
"claude-instant-1.2": {
|
||||
"context_window": 100000,
|
||||
"supports_tools": False,
|
||||
"supports_vision": False,
|
||||
},
|
||||
}
|
||||
|
||||
# Default config if model not found
|
||||
default_config = {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": False,
|
||||
}
|
||||
|
||||
# Try exact match first
|
||||
if self.model in model_configs:
|
||||
return model_configs[self.model]
|
||||
|
||||
# Try prefix match for versioned models
|
||||
for model_prefix, config in model_configs.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return config
|
||||
|
||||
return default_config
|
||||
|
||||
def _format_messages(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Format messages for Anthropic API.
|
||||
|
||||
Args:
|
||||
messages: Input messages as string or list of dicts
|
||||
|
||||
Returns:
|
||||
List of properly formatted message dicts
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return [{"role": "user", "content": messages}]
|
||||
|
||||
# Validate message format
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
|
||||
raise ValueError(
|
||||
"Each message must be a dict with 'role' and 'content' keys"
|
||||
)
|
||||
|
||||
# Claude requires alternating user/assistant messages and cannot start with system
|
||||
formatted_messages = []
|
||||
system_message = None
|
||||
|
||||
for msg in messages:
|
||||
if msg["role"] == "system":
|
||||
# Store system message separately - Claude handles it differently
|
||||
if system_message is None:
|
||||
system_message = msg["content"]
|
||||
else:
|
||||
system_message += "\n\n" + msg["content"]
|
||||
else:
|
||||
formatted_messages.append(msg)
|
||||
|
||||
# Ensure messages alternate and start with user
|
||||
if formatted_messages and formatted_messages[0]["role"] != "user":
|
||||
formatted_messages.insert(0, {"role": "user", "content": "Hello"})
|
||||
|
||||
# Store system message for later use
|
||||
self._system_message = system_message
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[dict]]:
|
||||
"""Format tools for Claude function calling.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions
|
||||
|
||||
Returns:
|
||||
Claude-formatted tool definitions
|
||||
"""
|
||||
if not tools or not self.model_config.get("supports_tools", True):
|
||||
return None
|
||||
|
||||
formatted_tools = []
|
||||
for tool in tools:
|
||||
# Convert to Claude tool format
|
||||
formatted_tool = {
|
||||
"name": tool.get("name", ""),
|
||||
"description": tool.get("description", ""),
|
||||
"input_schema": tool.get("parameters", {}),
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
return formatted_tools
|
||||
|
||||
def _handle_tool_calls(
|
||||
self,
|
||||
response,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""Handle tool calls from Claude response.
|
||||
|
||||
Args:
|
||||
response: Claude API response
|
||||
available_functions: Dict mapping function names to callables
|
||||
from_task: Optional task context
|
||||
from_agent: Optional agent context
|
||||
|
||||
Returns:
|
||||
Result of function execution or error message
|
||||
"""
|
||||
# Claude returns tool use in content blocks
|
||||
if not hasattr(response, "content") or not available_functions:
|
||||
return response.content[0].text if response.content else ""
|
||||
|
||||
# Look for tool use blocks
|
||||
for content_block in response.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "tool_use":
|
||||
function_name = content_block.name
|
||||
function_args = {}
|
||||
|
||||
if function_name not in available_functions:
|
||||
return f"Error: Function '{function_name}' not found in available functions"
|
||||
|
||||
try:
|
||||
# Claude provides arguments as a dict
|
||||
function_args = content_block.input
|
||||
fn = available_functions[function_name]
|
||||
|
||||
# Execute function with event tracking
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
),
|
||||
)
|
||||
|
||||
result = fn(**function_args)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# Emit success event
|
||||
event_data = {
|
||||
"response": result,
|
||||
"call_type": LLMCallType.TOOL_CALL,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(**event_data),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing function '{function_name}': {e}"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageErrorEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
error=error_msg,
|
||||
),
|
||||
)
|
||||
return error_msg
|
||||
|
||||
# If no tool calls, return text content
|
||||
return response.content[0].text if response.content else ""
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call Claude API with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM
|
||||
tools: Optional list of tool schemas
|
||||
callbacks: Optional callbacks to execute
|
||||
available_functions: Optional dict of available functions
|
||||
from_task: Optional task context
|
||||
from_agent: Optional agent context
|
||||
|
||||
Returns:
|
||||
LLM response or tool execution result
|
||||
|
||||
Raises:
|
||||
ValueError: If messages format is invalid
|
||||
RuntimeError: If API call fails
|
||||
"""
|
||||
# Emit call started event
|
||||
print("calling from native claude", messages)
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
|
||||
# Prepare event data
|
||||
started_event_data = {
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"callbacks": callbacks,
|
||||
"available_functions": available_functions,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
started_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
started_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(**started_event_data),
|
||||
)
|
||||
|
||||
try:
|
||||
# Format messages
|
||||
formatted_messages = self._format_messages(messages)
|
||||
system_message = getattr(self, "_system_message", None)
|
||||
|
||||
# Prepare API call parameters
|
||||
api_params = {
|
||||
"model": self.model,
|
||||
"messages": formatted_messages,
|
||||
"max_tokens": self.max_tokens or 4000, # Claude requires max_tokens
|
||||
}
|
||||
|
||||
# Add system message if present
|
||||
if system_message:
|
||||
api_params["system"] = system_message
|
||||
|
||||
# Add optional parameters that Claude supports
|
||||
if self.temperature is not None:
|
||||
api_params["temperature"] = self.temperature
|
||||
|
||||
if self.top_p is not None:
|
||||
api_params["top_p"] = self.top_p
|
||||
|
||||
if self.top_k is not None:
|
||||
api_params["top_k"] = self.top_k
|
||||
|
||||
if self.stop:
|
||||
api_params["stop_sequences"] = self.stop
|
||||
|
||||
# Add tools if provided and supported
|
||||
formatted_tools = self._format_tools(tools)
|
||||
if formatted_tools:
|
||||
api_params["tools"] = formatted_tools
|
||||
|
||||
# Execute callbacks before API call
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_start"):
|
||||
callback.on_llm_start(
|
||||
serialized={"name": self.__class__.__name__},
|
||||
prompts=[str(formatted_messages)],
|
||||
)
|
||||
|
||||
# Make API call
|
||||
if self.stream:
|
||||
response = self.client.messages.create(stream=True, **api_params)
|
||||
# Handle streaming (simplified implementation)
|
||||
full_response = ""
|
||||
try:
|
||||
for event in response:
|
||||
if hasattr(event, "type"):
|
||||
if event.type == "content_block_delta":
|
||||
if hasattr(event, "delta") and hasattr(
|
||||
event.delta, "text"
|
||||
):
|
||||
full_response += event.delta.text
|
||||
except Exception as e:
|
||||
# If streaming fails, fall back to the response we have
|
||||
print(f"Streaming error (continuing with partial response): {e}")
|
||||
result = full_response or "No response content"
|
||||
else:
|
||||
response = self.client.messages.create(**api_params)
|
||||
# Handle tool calls if present
|
||||
result = self._handle_tool_calls(
|
||||
response, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
# Execute callbacks after API call
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_end"):
|
||||
callback.on_llm_end(response=result)
|
||||
|
||||
# Emit completion event
|
||||
completion_event_data = {
|
||||
"messages": formatted_messages,
|
||||
"response": result,
|
||||
"call_type": LLMCallType.LLM_CALL,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
completion_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
completion_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(**completion_event_data),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# Execute error callbacks
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_error"):
|
||||
callback.on_llm_error(error=e)
|
||||
|
||||
# Emit failed event
|
||||
failed_event_data = {
|
||||
"error": str(e),
|
||||
}
|
||||
if from_task is not None:
|
||||
failed_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
failed_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(**failed_event_data),
|
||||
)
|
||||
|
||||
raise RuntimeError(f"Claude API call failed: {str(e)}") from e
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if Claude models support stop words."""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the current model."""
|
||||
if self.context_window_size != 0:
|
||||
return self.context_window_size
|
||||
|
||||
# Use 85% of the context window like the original LLM class
|
||||
context_window = self.model_config.get("context_window", 200000)
|
||||
self.context_window_size = int(context_window * 0.85)
|
||||
return self.context_window_size
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the current model supports function calling."""
|
||||
return self.model_config.get("supports_tools", True)
|
||||
|
||||
def supports_vision(self) -> bool:
|
||||
"""Check if the current model supports vision capabilities."""
|
||||
return self.model_config.get("supports_vision", False)
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Google Gemini LLM implementation for CrewAI."""
|
||||
|
||||
from .gemini import GeminiLLM
|
||||
|
||||
__all__ = ["GeminiLLM"]
|
||||
@@ -1,737 +0,0 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, Type, Literal, TYPE_CHECKING
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
except ImportError:
|
||||
genai = None
|
||||
types = None
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
)
|
||||
from crewai.utilities.events.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class GeminiLLM(BaseLLM):
|
||||
"""Google Gemini LLM implementation using the official Google Gen AI Python SDK."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gemini-1.5-pro",
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None, # Not supported by Gemini but kept for compatibility
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[
|
||||
float
|
||||
] = None, # Not supported but kept for compatibility
|
||||
frequency_penalty: Optional[
|
||||
float
|
||||
] = None, # Not supported but kept for compatibility
|
||||
logit_bias: Optional[
|
||||
Dict[int, float]
|
||||
] = None, # Not supported but kept for compatibility
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
seed: Optional[int] = None, # Not supported but kept for compatibility
|
||||
logprobs: Optional[int] = None, # Not supported but kept for compatibility
|
||||
top_logprobs: Optional[int] = None, # Not supported but kept for compatibility
|
||||
base_url: Optional[str] = None, # Not used by Gemini
|
||||
api_base: Optional[str] = None, # Not used by Gemini
|
||||
api_version: Optional[str] = None, # Not used by Gemini
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[
|
||||
Literal["none", "low", "medium", "high"]
|
||||
] = None, # Not used by Gemini
|
||||
stream: bool = False,
|
||||
max_retries: int = 2,
|
||||
# Gemini-specific parameters
|
||||
top_k: Optional[int] = None, # Gemini top-k sampling parameter
|
||||
candidate_count: int = 1, # Number of response candidates
|
||||
safety_settings: Optional[
|
||||
List[Dict[str, Any]]
|
||||
] = None, # Gemini safety settings
|
||||
generation_config: Optional[
|
||||
Dict[str, Any]
|
||||
] = None, # Additional generation config
|
||||
# Vertex AI parameters
|
||||
use_vertex_ai: bool = False,
|
||||
project_id: Optional[str] = None,
|
||||
location: str = "us-central1",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Gemini LLM with the official Google Gen AI SDK.
|
||||
|
||||
Args:
|
||||
model: Gemini model name (e.g., 'gemini-1.5-pro', 'gemini-2.0-flash-001')
|
||||
timeout: Request timeout in seconds
|
||||
temperature: Sampling temperature (0-2 for Gemini)
|
||||
top_p: Nucleus sampling parameter
|
||||
n: Number of completions (not supported by Gemini, kept for compatibility)
|
||||
stop: Stop sequences
|
||||
max_completion_tokens: Maximum tokens in completion
|
||||
max_tokens: Maximum tokens (legacy parameter)
|
||||
presence_penalty: Not supported by Gemini, kept for compatibility
|
||||
frequency_penalty: Not supported by Gemini, kept for compatibility
|
||||
logit_bias: Not supported by Gemini, kept for compatibility
|
||||
response_format: Pydantic model for structured output
|
||||
seed: Not supported by Gemini, kept for compatibility
|
||||
logprobs: Not supported by Gemini, kept for compatibility
|
||||
top_logprobs: Not supported by Gemini, kept for compatibility
|
||||
base_url: Not used by Gemini
|
||||
api_base: Not used by Gemini
|
||||
api_version: Not used by Gemini
|
||||
api_key: Google AI API key
|
||||
callbacks: List of callback functions
|
||||
reasoning_effort: Not used by Gemini, kept for compatibility
|
||||
stream: Whether to stream responses
|
||||
max_retries: Number of retries for failed requests
|
||||
top_k: Gemini-specific top-k sampling parameter
|
||||
candidate_count: Number of response candidates to generate
|
||||
safety_settings: Gemini safety settings configuration
|
||||
generation_config: Additional Gemini generation configuration
|
||||
use_vertex_ai: Whether to use Vertex AI instead of Gemini API
|
||||
project_id: Google Cloud project ID (required for Vertex AI)
|
||||
location: Google Cloud region (default: us-central1)
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
# Check if Google Gen AI SDK is available
|
||||
if genai is None or types is None:
|
||||
raise ImportError(
|
||||
"Google Gen AI Python SDK is required. Please install it with: "
|
||||
"pip install google-genai"
|
||||
)
|
||||
|
||||
super().__init__(model=model, temperature=temperature)
|
||||
|
||||
# Store all parameters for compatibility
|
||||
self.timeout = timeout
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens or max_completion_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.api_base = api_base
|
||||
self.base_url = base_url
|
||||
self.api_version = api_version
|
||||
self.callbacks = callbacks
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.stream = stream
|
||||
self.additional_params = kwargs
|
||||
self.context_window_size = 0
|
||||
self.max_retries = max_retries
|
||||
|
||||
# Gemini-specific parameters
|
||||
self.top_k = top_k
|
||||
self.candidate_count = candidate_count
|
||||
self.safety_settings = safety_settings or []
|
||||
self.generation_config = generation_config or {}
|
||||
|
||||
# Vertex AI parameters
|
||||
self.use_vertex_ai = use_vertex_ai
|
||||
self.project_id = project_id or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
self.location = location
|
||||
|
||||
# API key handling
|
||||
self.api_key = (
|
||||
api_key
|
||||
or os.getenv("GOOGLE_AI_API_KEY")
|
||||
or os.getenv("GEMINI_API_KEY")
|
||||
or os.getenv("GOOGLE_API_KEY")
|
||||
)
|
||||
|
||||
# Normalize stop parameter to match LLM class behavior
|
||||
if stop is None:
|
||||
self.stop: List[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = stop
|
||||
|
||||
# Initialize client attribute
|
||||
self.client: Any = None
|
||||
|
||||
# Initialize the Google Gen AI client
|
||||
self._initialize_client()
|
||||
self.model_config = self._get_model_config()
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize the Google Gen AI client."""
|
||||
if genai is None or types is None:
|
||||
return
|
||||
|
||||
try:
|
||||
if self.use_vertex_ai:
|
||||
if not self.project_id:
|
||||
raise ValueError(
|
||||
"project_id is required when use_vertex_ai=True. "
|
||||
"Set it directly or via GOOGLE_CLOUD_PROJECT environment variable."
|
||||
)
|
||||
self.client = genai.Client(
|
||||
vertexai=True,
|
||||
project=self.project_id,
|
||||
location=self.location,
|
||||
)
|
||||
else:
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"API key is required for Gemini Developer API. "
|
||||
"Set it via api_key parameter or GOOGLE_AI_API_KEY/GEMINI_API_KEY environment variable."
|
||||
)
|
||||
self.client = genai.Client(api_key=self.api_key)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to initialize Google Gen AI client: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _get_model_config(self) -> Dict[str, Any]:
|
||||
"""Get model-specific configuration for Gemini models."""
|
||||
# Gemini model configurations based on Google's documentation
|
||||
model_configs = {
|
||||
# Gemini 2.0 Flash (latest)
|
||||
"gemini-2.0-flash": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-2.0-flash-001": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-2.0-flash-exp": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Gemini 1.5 Pro
|
||||
"gemini-1.5-pro": {
|
||||
"context_window": 2097152,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-pro-002": {
|
||||
"context_window": 2097152,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-pro-001": {
|
||||
"context_window": 2097152,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-pro-exp-0827": {
|
||||
"context_window": 2097152,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Gemini 1.5 Flash
|
||||
"gemini-1.5-flash": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-flash-002": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-flash-001": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-flash-8b": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-flash-8b-exp-0827": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Legacy Gemini Pro
|
||||
"gemini-pro": {
|
||||
"context_window": 30720,
|
||||
"supports_tools": True,
|
||||
"supports_vision": False,
|
||||
},
|
||||
"gemini-pro-vision": {
|
||||
"context_window": 16384,
|
||||
"supports_tools": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Gemini Ultra (when available)
|
||||
"gemini-ultra": {
|
||||
"context_window": 30720,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
}
|
||||
|
||||
# Default config if model not found
|
||||
default_config = {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
}
|
||||
|
||||
# Try exact match first
|
||||
if self.model in model_configs:
|
||||
return model_configs[self.model]
|
||||
|
||||
# Try prefix match for versioned models
|
||||
for model_prefix, config in model_configs.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return config
|
||||
|
||||
return default_config
|
||||
|
||||
def _format_messages(self, messages: Union[str, List[Dict[str, str]]]) -> List[Any]:
|
||||
"""Format messages for Google Gen AI SDK.
|
||||
|
||||
Args:
|
||||
messages: Input messages as string or list of dicts
|
||||
|
||||
Returns:
|
||||
List of properly formatted Content objects
|
||||
"""
|
||||
if genai is None or types is None:
|
||||
return []
|
||||
|
||||
if isinstance(messages, str):
|
||||
return [
|
||||
types.Content(role="user", parts=[types.Part.from_text(text=messages)])
|
||||
]
|
||||
|
||||
# Validate message format
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
|
||||
raise ValueError(
|
||||
"Each message must be a dict with 'role' and 'content' keys"
|
||||
)
|
||||
|
||||
# Convert to Google Gen AI SDK format
|
||||
formatted_messages = []
|
||||
system_instruction = None
|
||||
|
||||
for msg in messages:
|
||||
role = msg["role"]
|
||||
content = msg["content"]
|
||||
|
||||
if role == "system":
|
||||
# System instruction will be handled separately
|
||||
system_instruction = content
|
||||
elif role == "user":
|
||||
formatted_messages.append(
|
||||
types.Content(
|
||||
role="user", parts=[types.Part.from_text(text=content)]
|
||||
)
|
||||
)
|
||||
elif role == "assistant":
|
||||
formatted_messages.append(
|
||||
types.Content(
|
||||
role="model", parts=[types.Part.from_text(text=content)]
|
||||
)
|
||||
)
|
||||
|
||||
# Store system instruction for later use
|
||||
self._system_instruction = system_instruction
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[Any]]:
|
||||
"""Format tools for Google Gen AI SDK function calling.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions
|
||||
|
||||
Returns:
|
||||
Google Gen AI SDK formatted tool definitions
|
||||
"""
|
||||
if genai is None or types is None:
|
||||
return None
|
||||
|
||||
if not tools or not self.model_config.get("supports_tools", True):
|
||||
return None
|
||||
|
||||
formatted_tools = []
|
||||
for tool in tools:
|
||||
# Convert to Google Gen AI SDK function declaration format
|
||||
function_declaration = types.FunctionDeclaration(
|
||||
name=tool.get("name", ""),
|
||||
description=tool.get("description", ""),
|
||||
parameters=tool.get("parameters", {}),
|
||||
)
|
||||
formatted_tools.append(
|
||||
types.Tool(function_declarations=[function_declaration])
|
||||
)
|
||||
|
||||
return formatted_tools
|
||||
|
||||
def _build_generation_config(
|
||||
self,
|
||||
system_instruction: Optional[str] = None,
|
||||
tools: Optional[List[Any]] = None,
|
||||
) -> Any:
|
||||
"""Build Google Gen AI SDK generation config from parameters."""
|
||||
if genai is None or types is None:
|
||||
return {}
|
||||
config_dict = self.generation_config.copy()
|
||||
|
||||
# Add parameters that map to Gemini's generation config
|
||||
if self.temperature is not None:
|
||||
config_dict["temperature"] = self.temperature
|
||||
|
||||
if self.top_p is not None:
|
||||
config_dict["top_p"] = self.top_p
|
||||
|
||||
if self.top_k is not None:
|
||||
config_dict["top_k"] = self.top_k
|
||||
|
||||
if self.max_tokens is not None:
|
||||
config_dict["max_output_tokens"] = self.max_tokens
|
||||
|
||||
if self.candidate_count is not None:
|
||||
config_dict["candidate_count"] = self.candidate_count
|
||||
|
||||
if self.stop:
|
||||
config_dict["stop_sequences"] = self.stop
|
||||
|
||||
if self.stream:
|
||||
config_dict["stream"] = True
|
||||
|
||||
# Add safety settings
|
||||
if self.safety_settings:
|
||||
config_dict["safety_settings"] = self.safety_settings
|
||||
|
||||
# Add response format if specified
|
||||
if self.response_format:
|
||||
config_dict["response_modalities"] = ["TEXT"]
|
||||
|
||||
# Add system instruction if present
|
||||
if system_instruction:
|
||||
config_dict["system_instruction"] = system_instruction
|
||||
|
||||
# Add tools if present
|
||||
if tools:
|
||||
config_dict["tools"] = tools
|
||||
|
||||
return types.GenerateContentConfig(**config_dict)
|
||||
|
||||
def _handle_tool_calls(
|
||||
self,
|
||||
response,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""Handle tool calls from Google Gen AI SDK response.
|
||||
|
||||
Args:
|
||||
response: Google Gen AI SDK response
|
||||
available_functions: Dict mapping function names to callables
|
||||
from_task: Optional task context
|
||||
from_agent: Optional agent context
|
||||
|
||||
Returns:
|
||||
Result of function execution or error message
|
||||
"""
|
||||
# Check if response has function calls
|
||||
if (
|
||||
not available_functions
|
||||
or not hasattr(response, "candidates")
|
||||
or not response.candidates
|
||||
):
|
||||
return response.text if hasattr(response, "text") else str(response)
|
||||
|
||||
candidate = response.candidates[0] if response.candidates else None
|
||||
if (
|
||||
not candidate
|
||||
or not hasattr(candidate, "content")
|
||||
or not hasattr(candidate.content, "parts")
|
||||
):
|
||||
return response.text if hasattr(response, "text") else str(response)
|
||||
|
||||
# Look for function call parts
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call"):
|
||||
function_call = part.function_call
|
||||
function_name = function_call.name
|
||||
function_args = {}
|
||||
|
||||
if function_name not in available_functions:
|
||||
return f"Error: Function '{function_name}' not found in available functions"
|
||||
|
||||
try:
|
||||
# Google Gen AI SDK provides arguments as a struct
|
||||
function_args = (
|
||||
dict(function_call.args)
|
||||
if hasattr(function_call, "args")
|
||||
else {}
|
||||
)
|
||||
fn = available_functions[function_name]
|
||||
|
||||
# Execute function with event tracking
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
),
|
||||
)
|
||||
|
||||
result = fn(**function_args)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# Emit success event
|
||||
event_data = {
|
||||
"response": result,
|
||||
"call_type": LLMCallType.TOOL_CALL,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(**event_data),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing function '{function_name}': {e}"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageErrorEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
error=error_msg,
|
||||
),
|
||||
)
|
||||
return error_msg
|
||||
|
||||
# If no function calls, return text content
|
||||
return response.text if hasattr(response, "text") else str(response)
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call Google Gen AI SDK with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM
|
||||
tools: Optional list of tool schemas
|
||||
callbacks: Optional callbacks to execute
|
||||
available_functions: Optional dict of available functions
|
||||
from_task: Optional task context
|
||||
from_agent: Optional agent context
|
||||
|
||||
Returns:
|
||||
LLM response or tool execution result
|
||||
|
||||
Raises:
|
||||
ValueError: If messages format is invalid
|
||||
RuntimeError: If API call fails
|
||||
"""
|
||||
# Emit call started event
|
||||
print("calling from native gemini", messages)
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
|
||||
# Prepare event data
|
||||
started_event_data = {
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"callbacks": callbacks,
|
||||
"available_functions": available_functions,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
started_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
started_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(**started_event_data),
|
||||
)
|
||||
|
||||
retry_count = 0
|
||||
last_error = None
|
||||
|
||||
while retry_count <= self.max_retries:
|
||||
try:
|
||||
# Format messages
|
||||
formatted_messages = self._format_messages(messages)
|
||||
system_instruction = getattr(self, "_system_instruction", None)
|
||||
|
||||
# Format tools if provided and supported
|
||||
formatted_tools = self._format_tools(tools)
|
||||
|
||||
# Build generation config
|
||||
generation_config = self._build_generation_config(
|
||||
system_instruction, formatted_tools
|
||||
)
|
||||
|
||||
# Execute callbacks before API call
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_start"):
|
||||
callback.on_llm_start(
|
||||
serialized={"name": self.__class__.__name__},
|
||||
prompts=[str(formatted_messages)],
|
||||
)
|
||||
|
||||
# Prepare the API call parameters
|
||||
api_params = {
|
||||
"model": self.model,
|
||||
"contents": formatted_messages,
|
||||
"config": generation_config,
|
||||
}
|
||||
|
||||
# Make API call
|
||||
if self.stream:
|
||||
# Streaming response
|
||||
response_stream = self.client.models.generate_content(**api_params)
|
||||
|
||||
full_response = ""
|
||||
try:
|
||||
for chunk in response_stream:
|
||||
if hasattr(chunk, "text") and chunk.text:
|
||||
full_response += chunk.text
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Streaming error (continuing with partial response): {e}"
|
||||
)
|
||||
|
||||
result = full_response or "No response content"
|
||||
else:
|
||||
# Non-streaming response
|
||||
response = self.client.models.generate_content(**api_params)
|
||||
|
||||
# Handle tool calls if present
|
||||
result = self._handle_tool_calls(
|
||||
response, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
# Execute callbacks after API call
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_end"):
|
||||
callback.on_llm_end(response=result)
|
||||
|
||||
# Emit completion event
|
||||
completion_event_data = {
|
||||
"messages": messages, # Use original messages, not formatted_messages
|
||||
"response": result,
|
||||
"call_type": LLMCallType.LLM_CALL,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
completion_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
completion_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(**completion_event_data),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
retry_count += 1
|
||||
|
||||
if retry_count <= self.max_retries:
|
||||
print(
|
||||
f"Gemini API call failed (attempt {retry_count}/{self.max_retries + 1}): {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# All retries exhausted
|
||||
# Execute error callbacks
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_error"):
|
||||
callback.on_llm_error(error=e)
|
||||
|
||||
# Emit failed event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=str(e)),
|
||||
)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Gemini API call failed after {self.max_retries + 1} attempts: {str(e)}"
|
||||
) from e
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if Gemini models support stop words."""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the current model."""
|
||||
if self.context_window_size != 0:
|
||||
return self.context_window_size
|
||||
|
||||
# Use 85% of the context window like the original LLM class
|
||||
context_window = self.model_config.get("context_window", 1048576)
|
||||
self.context_window_size = int(context_window * 0.85)
|
||||
return self.context_window_size
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the current model supports function calling."""
|
||||
return self.model_config.get("supports_tools", True)
|
||||
|
||||
def supports_vision(self) -> bool:
|
||||
"""Check if the current model supports vision capabilities."""
|
||||
return self.model_config.get("supports_vision", False)
|
||||
@@ -1,5 +0,0 @@
|
||||
"""OpenAI LLM implementation for CrewAI."""
|
||||
|
||||
from .chat import OpenAILLM
|
||||
|
||||
__all__ = ["OpenAILLM"]
|
||||
@@ -1,529 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, Type, Literal
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
)
|
||||
from crewai.utilities.events.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""OpenAI LLM implementation with full LLM class compatibility."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4",
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
seed: Optional[int] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
stream: bool = False,
|
||||
max_retries: int = 2,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize OpenAI LLM with full compatibility.
|
||||
|
||||
Args:
|
||||
model: OpenAI model name (e.g., 'gpt-4', 'gpt-3.5-turbo')
|
||||
timeout: Request timeout in seconds
|
||||
temperature: Sampling temperature (0-2)
|
||||
top_p: Nucleus sampling parameter
|
||||
n: Number of completions to generate
|
||||
stop: Stop sequences
|
||||
max_completion_tokens: Maximum tokens in completion
|
||||
max_tokens: Maximum tokens (legacy parameter)
|
||||
presence_penalty: Presence penalty (-2 to 2)
|
||||
frequency_penalty: Frequency penalty (-2 to 2)
|
||||
logit_bias: Logit bias dictionary
|
||||
response_format: Pydantic model for structured output
|
||||
seed: Random seed for deterministic output
|
||||
logprobs: Whether to return log probabilities
|
||||
top_logprobs: Number of most likely tokens to return
|
||||
base_url: Custom API base URL
|
||||
api_base: Legacy API base parameter
|
||||
api_version: API version (for Azure)
|
||||
api_key: OpenAI API key
|
||||
callbacks: List of callback functions
|
||||
reasoning_effort: Reasoning effort for o1 models
|
||||
stream: Whether to stream responses
|
||||
max_retries: Number of retries for failed requests
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(model=model, temperature=temperature)
|
||||
|
||||
# Store all parameters for compatibility
|
||||
self.timeout = timeout
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens or max_completion_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.api_base = api_base or base_url
|
||||
self.base_url = base_url or api_base
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.callbacks = callbacks
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.stream = stream
|
||||
self.additional_params = kwargs
|
||||
self.context_window_size = 0
|
||||
|
||||
# Normalize stop parameter to match LLM class behavior
|
||||
if stop is None:
|
||||
self.stop: List[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = stop
|
||||
|
||||
# Initialize OpenAI client
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
timeout=self.timeout,
|
||||
max_retries=max_retries,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.model_config = self._get_model_config()
|
||||
|
||||
def _get_model_config(self) -> Dict[str, Any]:
|
||||
"""Get model-specific configuration."""
|
||||
# Enhanced model configurations matching current LLM_CONTEXT_WINDOW_SIZES
|
||||
model_configs = {
|
||||
"gpt-4": {"context_window": 8192, "supports_tools": True},
|
||||
"gpt-4o": {"context_window": 128000, "supports_tools": True},
|
||||
"gpt-4o-mini": {"context_window": 200000, "supports_tools": True},
|
||||
"gpt-4-turbo": {"context_window": 128000, "supports_tools": True},
|
||||
"gpt-4.1": {"context_window": 1047576, "supports_tools": True},
|
||||
"gpt-4.1-mini": {"context_window": 1047576, "supports_tools": True},
|
||||
"gpt-4.1-nano": {"context_window": 1047576, "supports_tools": True},
|
||||
"gpt-3.5-turbo": {"context_window": 16385, "supports_tools": True},
|
||||
"o1-preview": {"context_window": 128000, "supports_tools": False},
|
||||
"o1-mini": {"context_window": 128000, "supports_tools": False},
|
||||
"o3-mini": {"context_window": 200000, "supports_tools": False},
|
||||
"o4-mini": {"context_window": 200000, "supports_tools": False},
|
||||
}
|
||||
|
||||
# Default config if model not found
|
||||
default_config = {"context_window": 4096, "supports_tools": True}
|
||||
|
||||
for model_prefix, config in model_configs.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return config
|
||||
|
||||
return default_config
|
||||
|
||||
def _format_messages(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Format messages for OpenAI API.
|
||||
|
||||
Args:
|
||||
messages: Input messages as string or list of dicts
|
||||
|
||||
Returns:
|
||||
List of properly formatted message dicts
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return [{"role": "user", "content": messages}]
|
||||
|
||||
# Validate message format
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
|
||||
raise ValueError(
|
||||
"Each message must be a dict with 'role' and 'content' keys"
|
||||
)
|
||||
|
||||
# Handle O1 model special case (system messages not supported)
|
||||
if "o1" in self.model.lower():
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
if msg["role"] == "system":
|
||||
# Convert system messages to assistant messages for O1
|
||||
formatted_messages.append(
|
||||
{"role": "assistant", "content": msg["content"]}
|
||||
)
|
||||
else:
|
||||
formatted_messages.append(msg)
|
||||
return formatted_messages
|
||||
|
||||
return messages
|
||||
|
||||
def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[dict]]:
|
||||
"""Format tools for OpenAI function calling.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions
|
||||
|
||||
Returns:
|
||||
OpenAI-formatted tool definitions
|
||||
"""
|
||||
if not tools or not self.model_config.get("supports_tools", True):
|
||||
return None
|
||||
|
||||
formatted_tools = []
|
||||
for tool in tools:
|
||||
# Convert to OpenAI tool format
|
||||
formatted_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.get("name", ""),
|
||||
"description": tool.get("description", ""),
|
||||
"parameters": tool.get("parameters", {}),
|
||||
},
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
return formatted_tools
|
||||
|
||||
def _handle_tool_calls(
|
||||
self,
|
||||
response,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""Handle tool calls from OpenAI response.
|
||||
|
||||
Args:
|
||||
response: OpenAI API response
|
||||
available_functions: Dict mapping function names to callables
|
||||
from_task: Optional task context
|
||||
from_agent: Optional agent context
|
||||
|
||||
Returns:
|
||||
Result of function execution or error message
|
||||
"""
|
||||
message = response.choices[0].message
|
||||
|
||||
if not message.tool_calls or not available_functions:
|
||||
return message.content
|
||||
|
||||
# Execute the first tool call
|
||||
tool_call = message.tool_calls[0]
|
||||
function_name = tool_call.function.name
|
||||
function_args = {}
|
||||
|
||||
if function_name not in available_functions:
|
||||
return f"Error: Function '{function_name}' not found in available functions"
|
||||
|
||||
try:
|
||||
# Parse function arguments
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
fn = available_functions[function_name]
|
||||
|
||||
# Execute function with event tracking
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
),
|
||||
)
|
||||
|
||||
result = fn(**function_args)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# Emit success event
|
||||
event_data = {
|
||||
"response": result,
|
||||
"call_type": LLMCallType.TOOL_CALL,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(**event_data),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
error_msg = f"Error parsing function arguments: {e}"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageErrorEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
error=error_msg,
|
||||
),
|
||||
)
|
||||
return error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing function '{function_name}': {e}"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageErrorEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
error=error_msg,
|
||||
),
|
||||
)
|
||||
return error_msg
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call OpenAI API with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM
|
||||
tools: Optional list of tool schemas
|
||||
callbacks: Optional callbacks to execute
|
||||
available_functions: Optional dict of available functions
|
||||
from_task: Optional task context
|
||||
from_agent: Optional agent context
|
||||
|
||||
Returns:
|
||||
LLM response or tool execution result
|
||||
|
||||
Raises:
|
||||
ValueError: If messages format is invalid
|
||||
RuntimeError: If API call fails
|
||||
"""
|
||||
# Emit call started event
|
||||
print("calling from native openai", messages)
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
|
||||
# Prepare event data
|
||||
started_event_data = {
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"callbacks": callbacks,
|
||||
"available_functions": available_functions,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
started_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
started_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(**started_event_data),
|
||||
)
|
||||
|
||||
try:
|
||||
# Format messages
|
||||
formatted_messages = self._format_messages(messages)
|
||||
|
||||
# Prepare API call parameters
|
||||
api_params = {
|
||||
"model": self.model,
|
||||
"messages": formatted_messages,
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if self.temperature is not None:
|
||||
api_params["temperature"] = self.temperature
|
||||
|
||||
if self.top_p is not None:
|
||||
api_params["top_p"] = self.top_p
|
||||
|
||||
if self.n is not None:
|
||||
api_params["n"] = self.n
|
||||
|
||||
if self.max_tokens is not None:
|
||||
api_params["max_tokens"] = self.max_tokens
|
||||
|
||||
if self.presence_penalty is not None:
|
||||
api_params["presence_penalty"] = self.presence_penalty
|
||||
|
||||
if self.frequency_penalty is not None:
|
||||
api_params["frequency_penalty"] = self.frequency_penalty
|
||||
|
||||
if self.logit_bias is not None:
|
||||
api_params["logit_bias"] = self.logit_bias
|
||||
|
||||
if self.seed is not None:
|
||||
api_params["seed"] = self.seed
|
||||
|
||||
if self.logprobs is not None:
|
||||
api_params["logprobs"] = self.logprobs
|
||||
|
||||
if self.top_logprobs is not None:
|
||||
api_params["top_logprobs"] = self.top_logprobs
|
||||
|
||||
if self.stop:
|
||||
api_params["stop"] = self.stop
|
||||
|
||||
if self.response_format is not None:
|
||||
# Handle structured output for Pydantic models
|
||||
if hasattr(self.response_format, "model_json_schema"):
|
||||
api_params["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": self.response_format.__name__,
|
||||
"schema": self.response_format.model_json_schema(),
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
else:
|
||||
api_params["response_format"] = self.response_format
|
||||
|
||||
if self.reasoning_effort is not None and "o1" in self.model:
|
||||
api_params["reasoning_effort"] = self.reasoning_effort
|
||||
|
||||
# Add tools if provided and supported
|
||||
formatted_tools = self._format_tools(tools)
|
||||
if formatted_tools:
|
||||
api_params["tools"] = formatted_tools
|
||||
api_params["tool_choice"] = "auto"
|
||||
|
||||
# Execute callbacks before API call
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_start"):
|
||||
callback.on_llm_start(
|
||||
serialized={"name": self.__class__.__name__},
|
||||
prompts=[str(formatted_messages)],
|
||||
)
|
||||
|
||||
# Make API call
|
||||
if self.stream:
|
||||
response = self.client.chat.completions.create(
|
||||
stream=True, **api_params
|
||||
)
|
||||
# Handle streaming (simplified for now)
|
||||
full_response = ""
|
||||
for chunk in response:
|
||||
if (
|
||||
hasattr(chunk.choices[0].delta, "content")
|
||||
and chunk.choices[0].delta.content
|
||||
):
|
||||
full_response += chunk.choices[0].delta.content
|
||||
result = full_response
|
||||
else:
|
||||
response = self.client.chat.completions.create(**api_params)
|
||||
# Handle tool calls if present
|
||||
result = self._handle_tool_calls(
|
||||
response, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
# If no tool calls, return text content
|
||||
if result == response.choices[0].message.content:
|
||||
result = response.choices[0].message.content or ""
|
||||
|
||||
# Execute callbacks after API call
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_end"):
|
||||
callback.on_llm_end(response=result)
|
||||
|
||||
# Emit completion event
|
||||
completion_event_data = {
|
||||
"messages": formatted_messages,
|
||||
"response": result,
|
||||
"call_type": LLMCallType.LLM_CALL,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
completion_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
completion_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(**completion_event_data),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# Execute error callbacks
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_error"):
|
||||
callback.on_llm_error(error=e)
|
||||
|
||||
# Emit failed event
|
||||
failed_event_data = {
|
||||
"error": str(e),
|
||||
}
|
||||
if from_task is not None:
|
||||
failed_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
failed_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(**failed_event_data),
|
||||
)
|
||||
|
||||
raise RuntimeError(f"OpenAI API call failed: {str(e)}") from e
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if OpenAI models support stop words."""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the current model."""
|
||||
if self.context_window_size != 0:
|
||||
return self.context_window_size
|
||||
|
||||
# Use 85% of the context window like the original LLM class
|
||||
context_window = self.model_config.get("context_window", 4096)
|
||||
self.context_window_size = int(context_window * 0.85)
|
||||
return self.context_window_size
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the current model supports function calling."""
|
||||
return self.model_config.get("supports_tools", True)
|
||||
@@ -1,556 +0,0 @@
|
||||
"""ChromaDB client implementation."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.api.types import (
|
||||
Embeddable,
|
||||
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||
QueryResult,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.chromadb.types import (
|
||||
ChromaDBClientType,
|
||||
ChromaDBCollectionCreateParams,
|
||||
ChromaDBCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.chromadb.utils import (
|
||||
_extract_search_params,
|
||||
_is_async_client,
|
||||
_is_sync_client,
|
||||
_prepare_documents_for_chromadb,
|
||||
_process_query_results,
|
||||
)
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class ChromaDBClient(BaseClient):
|
||||
"""ChromaDB implementation of the BaseClient protocol.
|
||||
|
||||
Provides vector database operations for ChromaDB, supporting both
|
||||
synchronous and asynchronous clients.
|
||||
|
||||
Attributes:
|
||||
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
"""
|
||||
|
||||
client: ChromaDBClientType
|
||||
embedding_function: ChromaEmbeddingFunction[Embeddable]
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> None:
|
||||
"""Create a new collection in ChromaDB.
|
||||
|
||||
Uses the client's default embedding function if none provided.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
get_or_create: If True, returns existing collection if it already exists
|
||||
instead of raising an error. Defaults to False.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection with the same name already exists and get_or_create
|
||||
is False.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"},
|
||||
... get_or_create=True
|
||||
... )
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method create_collection() requires a ClientAPI. "
|
||||
"Use acreate_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
self.client.create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
get_or_create=kwargs.get("get_or_create", False),
|
||||
)
|
||||
|
||||
async def acreate_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> None:
|
||||
"""Create a new collection in ChromaDB asynchronously.
|
||||
|
||||
Creates a new collection with the specified name and optional configuration.
|
||||
If an embedding function is not provided, uses the client's default embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
get_or_create: If True, returns existing collection if it already exists
|
||||
instead of raising an error. Defaults to False.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection with the same name already exists and get_or_create
|
||||
is False.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.acreate_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"},
|
||||
... get_or_create=True
|
||||
... )
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method acreate_collection() requires an AsyncClientAPI. "
|
||||
"Use create_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
await self.client.create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
get_or_create=kwargs.get("get_or_create", False),
|
||||
)
|
||||
|
||||
def get_or_create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist.
|
||||
|
||||
Returns existing collection if found, otherwise creates a new one.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
|
||||
Returns:
|
||||
A ChromaDB Collection object.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> collection = client.get_or_create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"}
|
||||
... )
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method get_or_create_collection() requires a ClientAPI. "
|
||||
"Use aget_or_create_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
return self.client.get_or_create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
)
|
||||
|
||||
async def aget_or_create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist asynchronously.
|
||||
|
||||
Returns existing collection if found, otherwise creates a new one.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
|
||||
Returns:
|
||||
A ChromaDB AsyncCollection object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... collection = await client.aget_or_create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"}
|
||||
... )
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method aget_or_create_collection() requires an AsyncClientAPI. "
|
||||
"Use get_or_create_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
return await self.client.get_or_create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
)
|
||||
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection.
|
||||
|
||||
Performs an upsert operation - documents with existing IDs are updated.
|
||||
Generates embeddings automatically using the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method add_documents() requires a ClientAPI. "
|
||||
"Use aadd_documents() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = self.client.get_collection(
|
||||
name=collection_name,
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
collection.add(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=prepared.metadatas,
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
|
||||
Performs an upsert operation - documents with existing IDs are updated.
|
||||
Generates embeddings automatically using the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method aadd_documents() requires an AsyncClientAPI. "
|
||||
"Use add_documents() for ClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
name=collection_name,
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
await collection.add(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=prepared.metadatas,
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Performs semantic search to find documents similar to the query text.
|
||||
Uses the configured embedding function to generate query embeddings.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
where: Optional ChromaDB where clause for metadata filtering.
|
||||
where_document: Optional ChromaDB where clause for document content filtering.
|
||||
include: Optional list of fields to include in results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method search() requires a ClientAPI. "
|
||||
"Use asearch() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = self.client.get_collection(
|
||||
name=params.collection_name,
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
results: QueryResult = collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Performs semantic search to find documents similar to the query text.
|
||||
Uses the configured embedding function to generate query embeddings.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
where: Optional ChromaDB where clause for metadata filtering.
|
||||
where_document: Optional ChromaDB where clause for document content filtering.
|
||||
include: Optional list of fields to include in results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method asearch() requires an AsyncClientAPI. "
|
||||
"Use search() for ClientAPI."
|
||||
)
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
name=params.collection_name,
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
results: QueryResult = await collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
|
||||
Permanently removes a collection and all documents, embeddings, and metadata it contains.
|
||||
This operation cannot be undone.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.delete_collection(collection_name="old_documents")
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method delete_collection() requires a ClientAPI. "
|
||||
"Use adelete_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
self.client.delete_collection(name=collection_name)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
|
||||
Permanently removes a collection and all documents, embeddings, and metadata it contains.
|
||||
This operation cannot be undone.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.adelete_collection(collection_name="old_documents")
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method adelete_collection() requires an AsyncClientAPI. "
|
||||
"Use delete_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
await self.client.delete_collection(name=collection_name)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
Completely clears the ChromaDB instance, removing all collections,
|
||||
documents, embeddings, and metadata. This operation cannot be undone.
|
||||
Use with extreme caution in production environments.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.reset() # Removes ALL data from ChromaDB
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method reset() requires a ClientAPI. "
|
||||
"Use areset() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
self.client.reset()
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
|
||||
Completely clears the ChromaDB instance, removing all collections,
|
||||
documents, embeddings, and metadata. This operation cannot be undone.
|
||||
Use with extreme caution in production environments.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.areset() # Removes ALL data from ChromaDB
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method areset() requires an AsyncClientAPI. "
|
||||
"Use reset() for ClientAPI."
|
||||
)
|
||||
|
||||
await self.client.reset()
|
||||
@@ -1,85 +0,0 @@
|
||||
"""Type definitions specific to ChromaDB implementation."""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from chromadb.api import ClientAPI, AsyncClientAPI
|
||||
from chromadb.api.configuration import CollectionConfigurationInterface
|
||||
from chromadb.api.types import (
|
||||
CollectionMetadata,
|
||||
DataLoader,
|
||||
Embeddable,
|
||||
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||
Include,
|
||||
Loadable,
|
||||
Where,
|
||||
WhereDocument,
|
||||
)
|
||||
|
||||
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
|
||||
|
||||
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
||||
|
||||
|
||||
class PreparedDocuments(NamedTuple):
|
||||
"""Prepared documents ready for ChromaDB insertion.
|
||||
|
||||
Attributes:
|
||||
ids: List of document IDs
|
||||
texts: List of document texts
|
||||
metadatas: List of document metadata mappings
|
||||
"""
|
||||
|
||||
ids: list[str]
|
||||
texts: list[str]
|
||||
metadatas: list[Mapping[str, str | int | float | bool]]
|
||||
|
||||
|
||||
class ExtractedSearchParams(NamedTuple):
|
||||
"""Extracted search parameters for ChromaDB queries.
|
||||
|
||||
Attributes:
|
||||
collection_name: Name of the collection to search
|
||||
query: Search query text
|
||||
limit: Maximum number of results
|
||||
metadata_filter: Optional metadata filter
|
||||
score_threshold: Optional minimum similarity score
|
||||
where: Optional ChromaDB where clause
|
||||
where_document: Optional ChromaDB document filter
|
||||
include: Fields to include in results
|
||||
"""
|
||||
|
||||
collection_name: str
|
||||
query: str
|
||||
limit: int
|
||||
metadata_filter: dict[str, Any] | None
|
||||
score_threshold: float | None
|
||||
where: Where | None
|
||||
where_document: WhereDocument | None
|
||||
include: Include
|
||||
|
||||
|
||||
class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
|
||||
"""Parameters for creating a ChromaDB collection.
|
||||
|
||||
This class extends BaseCollectionParams to include any additional
|
||||
parameters specific to ChromaDB collection creation.
|
||||
"""
|
||||
|
||||
configuration: CollectionConfigurationInterface
|
||||
metadata: CollectionMetadata
|
||||
embedding_function: ChromaEmbeddingFunction[Embeddable]
|
||||
data_loader: DataLoader[Loadable]
|
||||
get_or_create: bool
|
||||
|
||||
|
||||
class ChromaDBCollectionSearchParams(BaseCollectionSearchParams, total=False):
|
||||
"""Parameters for searching a ChromaDB collection.
|
||||
|
||||
This class extends BaseCollectionSearchParams to include ChromaDB-specific
|
||||
search parameters like where clauses and include options.
|
||||
"""
|
||||
|
||||
where: Where
|
||||
where_document: WhereDocument
|
||||
include: Include
|
||||
@@ -1,220 +0,0 @@
|
||||
"""Utility functions for ChromaDB client implementation."""
|
||||
|
||||
import hashlib
|
||||
from collections.abc import Mapping
|
||||
from typing import Literal, TypeGuard, cast
|
||||
|
||||
from chromadb.api import AsyncClientAPI, ClientAPI
|
||||
from chromadb.api.types import (
|
||||
Include,
|
||||
IncludeEnum,
|
||||
QueryResult,
|
||||
)
|
||||
|
||||
from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||
from chromadb.api.models.Collection import Collection
|
||||
|
||||
from crewai.rag.chromadb.types import (
|
||||
ChromaDBClientType,
|
||||
ChromaDBCollectionSearchParams,
|
||||
ExtractedSearchParams,
|
||||
PreparedDocuments,
|
||||
)
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
|
||||
|
||||
def _is_sync_client(client: ChromaDBClientType) -> TypeGuard[ClientAPI]:
|
||||
"""Type guard to check if the client is a synchronous ClientAPI.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is a ClientAPI, False otherwise.
|
||||
"""
|
||||
return isinstance(client, ClientAPI)
|
||||
|
||||
|
||||
def _is_async_client(client: ChromaDBClientType) -> TypeGuard[AsyncClientAPI]:
|
||||
"""Type guard to check if the client is an asynchronous AsyncClientAPI.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is an AsyncClientAPI, False otherwise.
|
||||
"""
|
||||
return isinstance(client, AsyncClientAPI)
|
||||
|
||||
|
||||
def _prepare_documents_for_chromadb(
|
||||
documents: list[BaseRecord],
|
||||
) -> PreparedDocuments:
|
||||
"""Prepare documents for ChromaDB by extracting IDs, texts, and metadata.
|
||||
|
||||
Args:
|
||||
documents: List of BaseRecord documents to prepare.
|
||||
|
||||
Returns:
|
||||
PreparedDocuments with ids, texts, and metadatas ready for ChromaDB.
|
||||
"""
|
||||
ids: list[str] = []
|
||||
texts: list[str] = []
|
||||
metadatas: list[Mapping[str, str | int | float | bool]] = []
|
||||
|
||||
for doc in documents:
|
||||
if "doc_id" in doc:
|
||||
ids.append(doc["doc_id"])
|
||||
else:
|
||||
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
|
||||
ids.append(content_hash)
|
||||
|
||||
texts.append(doc["content"])
|
||||
metadata = doc.get("metadata")
|
||||
if metadata:
|
||||
if isinstance(metadata, list):
|
||||
metadatas.append(metadata[0] if metadata else {})
|
||||
else:
|
||||
metadatas.append(metadata)
|
||||
else:
|
||||
metadatas.append({})
|
||||
|
||||
return PreparedDocuments(ids, texts, metadatas)
|
||||
|
||||
|
||||
def _extract_search_params(
|
||||
kwargs: ChromaDBCollectionSearchParams,
|
||||
) -> ExtractedSearchParams:
|
||||
"""Extract search parameters from kwargs.
|
||||
|
||||
Args:
|
||||
kwargs: Keyword arguments containing search parameters.
|
||||
|
||||
Returns:
|
||||
ExtractedSearchParams with all extracted parameters.
|
||||
"""
|
||||
return ExtractedSearchParams(
|
||||
collection_name=kwargs["collection_name"],
|
||||
query=kwargs["query"],
|
||||
limit=kwargs.get("limit", 10),
|
||||
metadata_filter=kwargs.get("metadata_filter"),
|
||||
score_threshold=kwargs.get("score_threshold"),
|
||||
where=kwargs.get("where"),
|
||||
where_document=kwargs.get("where_document"),
|
||||
include=kwargs.get(
|
||||
"include",
|
||||
[IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _convert_distance_to_score(
|
||||
distance: float,
|
||||
distance_metric: Literal["l2", "cosine", "ip"],
|
||||
) -> float:
|
||||
"""Convert ChromaDB distance to similarity score.
|
||||
|
||||
Notes:
|
||||
Assuming all embedding are unit-normalized for now, including custom embeddings.
|
||||
|
||||
Args:
|
||||
distance: The distance value from ChromaDB.
|
||||
distance_metric: The distance metric used ("l2", "cosine", or "ip").
|
||||
|
||||
Returns:
|
||||
Similarity score in range [0, 1] where 1 is most similar.
|
||||
"""
|
||||
if distance_metric == "cosine":
|
||||
score = 1.0 - 0.5 * distance
|
||||
return max(0.0, min(1.0, score))
|
||||
raise ValueError(f"Unsupported distance metric: {distance_metric}")
|
||||
|
||||
|
||||
def _convert_chromadb_results_to_search_results(
|
||||
results: QueryResult,
|
||||
include: Include,
|
||||
distance_metric: Literal["l2", "cosine", "ip"],
|
||||
score_threshold: float | None = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Convert ChromaDB query results to SearchResult format.
|
||||
|
||||
Args:
|
||||
results: ChromaDB query results.
|
||||
include: List of fields that were included in the query.
|
||||
distance_metric: The distance metric used by the collection.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
"""
|
||||
search_results: list[SearchResult] = []
|
||||
|
||||
include_strings = [item.value for item in include]
|
||||
|
||||
ids = results["ids"][0] if results.get("ids") else []
|
||||
|
||||
documents_list = results.get("documents")
|
||||
documents = (
|
||||
documents_list[0] if documents_list and "documents" in include_strings else []
|
||||
)
|
||||
|
||||
metadatas_list = results.get("metadatas")
|
||||
metadatas = (
|
||||
metadatas_list[0] if metadatas_list and "metadatas" in include_strings else []
|
||||
)
|
||||
|
||||
distances_list = results.get("distances")
|
||||
distances = (
|
||||
distances_list[0] if distances_list and "distances" in include_strings else []
|
||||
)
|
||||
|
||||
for i, doc_id in enumerate(ids):
|
||||
if not distances or i >= len(distances):
|
||||
continue
|
||||
|
||||
distance = distances[i]
|
||||
score = _convert_distance_to_score(
|
||||
distance=distance, distance_metric=distance_metric
|
||||
)
|
||||
|
||||
if score_threshold and score < score_threshold:
|
||||
continue
|
||||
|
||||
result: SearchResult = {
|
||||
"id": doc_id,
|
||||
"content": documents[i] if documents and i < len(documents) else "",
|
||||
"metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {},
|
||||
"score": score,
|
||||
}
|
||||
search_results.append(result)
|
||||
|
||||
return search_results
|
||||
|
||||
|
||||
def _process_query_results(
|
||||
collection: Collection | AsyncCollection,
|
||||
results: QueryResult,
|
||||
params: ExtractedSearchParams,
|
||||
) -> list[SearchResult]:
|
||||
"""Process ChromaDB query results and convert to SearchResult format.
|
||||
|
||||
Args:
|
||||
collection: The ChromaDB collection (sync or async) that was queried.
|
||||
results: Raw query results from ChromaDB.
|
||||
params: The search parameters used for the query.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
"""
|
||||
|
||||
distance_metric = cast(
|
||||
Literal["l2", "cosine", "ip"],
|
||||
collection.metadata.get("hnsw:space", "l2") if collection.metadata else "l2",
|
||||
)
|
||||
|
||||
return _convert_chromadb_results_to_search_results(
|
||||
results=results,
|
||||
include=params.include,
|
||||
distance_metric=distance_metric,
|
||||
score_threshold=params.score_threshold,
|
||||
)
|
||||
@@ -7,71 +7,39 @@ from crewai.llm import LLM, BaseLLM
|
||||
|
||||
def create_llm(
|
||||
llm_value: Union[str, LLM, Any, None] = None,
|
||||
prefer_native: Optional[bool] = None,
|
||||
) -> Optional[LLM | BaseLLM]:
|
||||
"""
|
||||
Creates or returns an LLM instance based on the given llm_value.
|
||||
Now supports provider prefixes like 'openai/gpt-4' for native implementations.
|
||||
|
||||
Args:
|
||||
llm_value (str | BaseLLM | Any | None):
|
||||
- str: The model name (e.g., "gpt-4" or "openai/gpt-4").
|
||||
- str: The model name (e.g., "gpt-4").
|
||||
- BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
|
||||
- Any: Attempt to extract known attributes like model_name, temperature, etc.
|
||||
- None: Use environment-based or fallback default model.
|
||||
prefer_native (bool | None):
|
||||
- True: Use native provider implementations when available
|
||||
- False: Always use LiteLLM implementation
|
||||
- None: Use environment variable CREWAI_PREFER_NATIVE_LLMS (default: True)
|
||||
- Note: Provider prefixes (openai/, anthropic/) override this setting
|
||||
|
||||
Returns:
|
||||
A BaseLLM instance if successful, or None if something fails.
|
||||
|
||||
Examples:
|
||||
create_llm("gpt-4") # Uses LiteLLM or native based on prefer_native
|
||||
create_llm("openai/gpt-4") # Always uses native OpenAI implementation
|
||||
create_llm("anthropic/claude-3-sonnet") # Future: native Anthropic
|
||||
"""
|
||||
|
||||
# 1) If llm_value is already a BaseLLM or LLM object, return it directly
|
||||
if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM):
|
||||
return llm_value
|
||||
|
||||
# 2) Determine if we should prefer native implementations (unless provider prefix is used)
|
||||
if prefer_native is None:
|
||||
prefer_native = os.getenv("CREWAI_PREFER_NATIVE_LLMS", "true").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
"yes",
|
||||
)
|
||||
|
||||
# 3) If llm_value is a string (model name)
|
||||
# 2) If llm_value is a string (model name)
|
||||
if isinstance(llm_value, str):
|
||||
try:
|
||||
# Provider prefix (openai/, anthropic/) always takes precedence
|
||||
if "/" in llm_value:
|
||||
created_llm = LLM(model=llm_value) # LLM class handles routing
|
||||
return created_llm
|
||||
|
||||
# Try native implementation first if preferred and no prefix
|
||||
if prefer_native:
|
||||
native_llm = _create_native_llm(llm_value)
|
||||
if native_llm:
|
||||
return native_llm
|
||||
|
||||
# Fallback to LiteLLM
|
||||
created_llm = LLM(model=llm_value)
|
||||
return created_llm
|
||||
except Exception as e:
|
||||
print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
|
||||
return None
|
||||
|
||||
# 4) If llm_value is None, parse environment variables or use default
|
||||
# 3) If llm_value is None, parse environment variables or use default
|
||||
if llm_value is None:
|
||||
return _llm_via_environment_or_fallback(prefer_native)
|
||||
return _llm_via_environment_or_fallback()
|
||||
|
||||
# 5) Otherwise, attempt to extract relevant attributes from an unknown object
|
||||
# 4) Otherwise, attempt to extract relevant attributes from an unknown object
|
||||
try:
|
||||
# Extract attributes with explicit types
|
||||
model = (
|
||||
@@ -80,8 +48,6 @@ def create_llm(
|
||||
or getattr(llm_value, "deployment_name", None)
|
||||
or str(llm_value)
|
||||
)
|
||||
|
||||
# Extract other parameters
|
||||
temperature: Optional[float] = getattr(llm_value, "temperature", None)
|
||||
max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None)
|
||||
logprobs: Optional[int] = getattr(llm_value, "logprobs", None)
|
||||
@@ -90,7 +56,6 @@ def create_llm(
|
||||
base_url: Optional[str] = getattr(llm_value, "base_url", None)
|
||||
api_base: Optional[str] = getattr(llm_value, "api_base", None)
|
||||
|
||||
# Use LLM class constructor which handles routing
|
||||
created_llm = LLM(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
@@ -107,94 +72,9 @@ def create_llm(
|
||||
return None
|
||||
|
||||
|
||||
def _create_native_llm(model: str, **kwargs) -> Optional[BaseLLM]:
|
||||
"""
|
||||
Create a native LLM implementation based on the model name.
|
||||
|
||||
Args:
|
||||
model: The model name (e.g., 'gpt-4', 'claude-3-sonnet')
|
||||
**kwargs: Additional parameters for the LLM
|
||||
|
||||
Returns:
|
||||
Native LLM instance if supported, None otherwise
|
||||
"""
|
||||
try:
|
||||
# OpenAI models
|
||||
if _is_openai_model(model):
|
||||
from crewai.llms.openai import OpenAILLM
|
||||
|
||||
return OpenAILLM(model=model, **kwargs)
|
||||
|
||||
# Claude models
|
||||
if _is_claude_model(model):
|
||||
from crewai.llms.anthropic import ClaudeLLM
|
||||
|
||||
return ClaudeLLM(model=model, **kwargs)
|
||||
|
||||
# Gemini models
|
||||
if _is_gemini_model(model):
|
||||
from crewai.llms.google import GeminiLLM
|
||||
|
||||
return GeminiLLM(model=model, **kwargs)
|
||||
|
||||
# No native implementation found
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to create native LLM for model '{model}': {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _is_openai_model(model: str) -> bool:
|
||||
"""Check if a model is from OpenAI."""
|
||||
openai_prefixes = (
|
||||
"gpt-",
|
||||
"text-davinci",
|
||||
"text-curie",
|
||||
"text-babbage",
|
||||
"text-ada",
|
||||
"davinci",
|
||||
"curie",
|
||||
"babbage",
|
||||
"ada",
|
||||
"o1-",
|
||||
"o3-",
|
||||
"o4-",
|
||||
"chatgpt-",
|
||||
)
|
||||
|
||||
model_lower = model.lower()
|
||||
return any(model_lower.startswith(prefix) for prefix in openai_prefixes)
|
||||
|
||||
|
||||
def _is_claude_model(model: str) -> bool:
|
||||
"""Check if a model is from Anthropic (Claude)."""
|
||||
claude_prefixes = (
|
||||
"claude-",
|
||||
"claude", # For cases like just "claude"
|
||||
)
|
||||
|
||||
model_lower = model.lower()
|
||||
return any(model_lower.startswith(prefix) for prefix in claude_prefixes)
|
||||
|
||||
|
||||
def _is_gemini_model(model: str) -> bool:
|
||||
"""Check if a model is from Google (Gemini)."""
|
||||
gemini_prefixes = (
|
||||
"gemini-",
|
||||
"gemini", # For cases like just "gemini"
|
||||
)
|
||||
|
||||
model_lower = model.lower()
|
||||
return any(model_lower.startswith(prefix) for prefix in gemini_prefixes)
|
||||
|
||||
|
||||
def _llm_via_environment_or_fallback(
|
||||
prefer_native: bool = True,
|
||||
) -> Optional[LLM | BaseLLM]:
|
||||
def _llm_via_environment_or_fallback() -> Optional[LLM]:
|
||||
"""
|
||||
Helper function: if llm_value is None, we load environment variables or fallback default model.
|
||||
Now with native provider support.
|
||||
"""
|
||||
model_name = (
|
||||
os.environ.get("MODEL")
|
||||
@@ -203,13 +83,7 @@ def _llm_via_environment_or_fallback(
|
||||
or DEFAULT_LLM_MODEL
|
||||
)
|
||||
|
||||
# Try native implementation first if preferred
|
||||
if prefer_native:
|
||||
native_llm = _create_native_llm(model_name)
|
||||
if native_llm:
|
||||
return native_llm
|
||||
|
||||
# Initialize parameters with correct types (original logic continues)
|
||||
# Initialize parameters with correct types
|
||||
model: str = model_name
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
|
||||
@@ -15,37 +15,37 @@ def mock_llm_responses():
|
||||
"ready": "I'll solve this simple math problem.\n\nREADY: I am ready to execute the task.\n\n",
|
||||
"not_ready": "I need to think about derivatives.\n\nNOT READY: I need to refine my plan because I'm not sure about the derivative rules.",
|
||||
"ready_after_refine": "I'll use the power rule for derivatives where d/dx(x^n) = n*x^(n-1).\n\nREADY: I am ready to execute the task.",
|
||||
"execution": "4",
|
||||
"execution": "4"
|
||||
}
|
||||
|
||||
|
||||
def test_agent_with_reasoning(mock_llm_responses):
|
||||
"""Test agent with reasoning."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
backstory="I am a test agent created to verify the reasoning feature works correctly.",
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
verbose=True,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Simple math task: What's 2+2?",
|
||||
expected_output="The answer should be a number.",
|
||||
agent=agent,
|
||||
agent=agent
|
||||
)
|
||||
|
||||
|
||||
agent.llm.call = lambda messages, *args, **kwargs: (
|
||||
mock_llm_responses["ready"]
|
||||
if any("create a detailed plan" in msg.get("content", "") for msg in messages)
|
||||
else mock_llm_responses["execution"]
|
||||
)
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == mock_llm_responses["execution"]
|
||||
assert "Reasoning Plan:" in task.description
|
||||
|
||||
@@ -53,7 +53,7 @@ def test_agent_with_reasoning(mock_llm_responses):
|
||||
def test_agent_with_reasoning_not_ready_initially(mock_llm_responses):
|
||||
"""Test agent with reasoning that requires refinement."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
@@ -61,21 +61,19 @@ def test_agent_with_reasoning_not_ready_initially(mock_llm_responses):
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
max_reasoning_attempts=2,
|
||||
verbose=True,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Complex math task: What's the derivative of x²?",
|
||||
expected_output="The answer should be a mathematical expression.",
|
||||
agent=agent,
|
||||
agent=agent
|
||||
)
|
||||
|
||||
|
||||
call_count = [0]
|
||||
|
||||
|
||||
def mock_llm_call(messages, *args, **kwargs):
|
||||
if any(
|
||||
"create a detailed plan" in msg.get("content", "") for msg in messages
|
||||
) or any("refine your plan" in msg.get("content", "") for msg in messages):
|
||||
if any("create a detailed plan" in msg.get("content", "") for msg in messages) or any("refine your plan" in msg.get("content", "") for msg in messages):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return mock_llm_responses["not_ready"]
|
||||
@@ -83,11 +81,11 @@ def test_agent_with_reasoning_not_ready_initially(mock_llm_responses):
|
||||
return mock_llm_responses["ready_after_refine"]
|
||||
else:
|
||||
return "2x"
|
||||
|
||||
|
||||
agent.llm.call = mock_llm_call
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == "2x"
|
||||
assert call_count[0] == 2 # Should have made 2 reasoning calls
|
||||
assert "Reasoning Plan:" in task.description
|
||||
@@ -96,7 +94,7 @@ def test_agent_with_reasoning_not_ready_initially(mock_llm_responses):
|
||||
def test_agent_with_reasoning_max_attempts_reached():
|
||||
"""Test agent with reasoning that reaches max attempts without being ready."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
@@ -104,53 +102,52 @@ def test_agent_with_reasoning_max_attempts_reached():
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
max_reasoning_attempts=2,
|
||||
verbose=True,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Complex math task: Solve the Riemann hypothesis.",
|
||||
expected_output="A proof or disproof of the hypothesis.",
|
||||
agent=agent,
|
||||
agent=agent
|
||||
)
|
||||
|
||||
|
||||
call_count = [0]
|
||||
|
||||
|
||||
def mock_llm_call(messages, *args, **kwargs):
|
||||
if any(
|
||||
"create a detailed plan" in msg.get("content", "") for msg in messages
|
||||
) or any("refine your plan" in msg.get("content", "") for msg in messages):
|
||||
if any("create a detailed plan" in msg.get("content", "") for msg in messages) or any("refine your plan" in msg.get("content", "") for msg in messages):
|
||||
call_count[0] += 1
|
||||
return f"Attempt {call_count[0]}: I need more time to think.\n\nNOT READY: I need to refine my plan further."
|
||||
else:
|
||||
return "This is an unsolved problem in mathematics."
|
||||
|
||||
|
||||
agent.llm.call = mock_llm_call
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == "This is an unsolved problem in mathematics."
|
||||
assert (
|
||||
call_count[0] == 2
|
||||
) # Should have made exactly 2 reasoning calls (max_attempts)
|
||||
assert call_count[0] == 2 # Should have made exactly 2 reasoning calls (max_attempts)
|
||||
assert "Reasoning Plan:" in task.description
|
||||
|
||||
|
||||
def test_agent_reasoning_input_validation():
|
||||
"""Test input validation in AgentReasoning."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
backstory="I am a test agent created to verify the reasoning feature works correctly.",
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
reasoning=True
|
||||
)
|
||||
|
||||
|
||||
with pytest.raises(ValueError, match="Both task and agent must be provided"):
|
||||
AgentReasoning(task=None, agent=agent)
|
||||
|
||||
task = Task(description="Simple task", expected_output="Simple output")
|
||||
|
||||
task = Task(
|
||||
description="Simple task",
|
||||
expected_output="Simple output"
|
||||
)
|
||||
with pytest.raises(ValueError, match="Both task and agent must be provided"):
|
||||
AgentReasoning(task=task, agent=None)
|
||||
|
||||
@@ -158,33 +155,33 @@ def test_agent_reasoning_input_validation():
|
||||
def test_agent_reasoning_error_handling():
|
||||
"""Test error handling during the reasoning process."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
backstory="I am a test agent created to verify the reasoning feature works correctly.",
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
reasoning=True
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Task that will cause an error",
|
||||
expected_output="Output that will never be generated",
|
||||
agent=agent,
|
||||
agent=agent
|
||||
)
|
||||
|
||||
|
||||
call_count = [0]
|
||||
|
||||
|
||||
def mock_llm_call_error(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 2: # First calls are for reasoning
|
||||
raise Exception("LLM error during reasoning")
|
||||
return "Fallback execution result" # Return a value for task execution
|
||||
|
||||
|
||||
agent.llm.call = mock_llm_call_error
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == "Fallback execution result"
|
||||
assert call_count[0] > 2 # Ensure we called the mock multiple times
|
||||
|
||||
@@ -192,36 +189,37 @@ def test_agent_reasoning_error_handling():
|
||||
def test_agent_with_function_calling():
|
||||
"""Test agent with reasoning using function calling."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
backstory="I am a test agent created to verify the reasoning feature works correctly.",
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
verbose=True,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Simple math task: What's 2+2?",
|
||||
expected_output="The answer should be a number.",
|
||||
agent=agent,
|
||||
agent=agent
|
||||
)
|
||||
|
||||
|
||||
agent.llm.supports_function_calling = lambda: True
|
||||
|
||||
|
||||
def mock_function_call(messages, *args, **kwargs):
|
||||
if "tools" in kwargs:
|
||||
return json.dumps(
|
||||
{"plan": "I'll solve this simple math problem: 2+2=4.", "ready": True}
|
||||
)
|
||||
return json.dumps({
|
||||
"plan": "I'll solve this simple math problem: 2+2=4.",
|
||||
"ready": True
|
||||
})
|
||||
else:
|
||||
return "4"
|
||||
|
||||
|
||||
agent.llm.call = mock_function_call
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == "4"
|
||||
assert "Reasoning Plan:" in task.description
|
||||
assert "I'll solve this simple math problem: 2+2=4." in task.description
|
||||
@@ -230,34 +228,34 @@ def test_agent_with_function_calling():
|
||||
def test_agent_with_function_calling_fallback():
|
||||
"""Test agent with reasoning using function calling that falls back to text parsing."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
backstory="I am a test agent created to verify the reasoning feature works correctly.",
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
verbose=True,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Simple math task: What's 2+2?",
|
||||
expected_output="The answer should be a number.",
|
||||
agent=agent,
|
||||
agent=agent
|
||||
)
|
||||
|
||||
|
||||
agent.llm.supports_function_calling = lambda: True
|
||||
|
||||
|
||||
def mock_function_call(messages, *args, **kwargs):
|
||||
if "tools" in kwargs:
|
||||
return "Invalid JSON that will trigger fallback. READY: I am ready to execute the task."
|
||||
else:
|
||||
return "4"
|
||||
|
||||
|
||||
agent.llm.call = mock_function_call
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == "4"
|
||||
assert "Reasoning Plan:" in task.description
|
||||
assert "Invalid JSON that will trigger fallback" in task.description
|
||||
@@ -23,7 +23,6 @@ from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.tool_usage_events import ToolUsageFinishedEvent
|
||||
from crewai.process import Process
|
||||
|
||||
|
||||
def test_agent_llm_creation_with_env_vars():
|
||||
# Store original environment variables
|
||||
original_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
@@ -236,7 +235,7 @@ def test_logging_tool_usage():
|
||||
)
|
||||
|
||||
assert agent.llm.model == "gpt-4o-mini"
|
||||
assert agent.tools_handler.last_used_tool is None
|
||||
assert agent.tools_handler.last_used_tool == {}
|
||||
task = Task(
|
||||
description="What is 3 times 4?",
|
||||
agent=agent,
|
||||
@@ -594,17 +593,42 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = (
|
||||
captured.out.replace("\n", " ")
|
||||
.replace(" ", " ")
|
||||
.strip()
|
||||
.replace("╭", "")
|
||||
.replace("╮", "")
|
||||
.replace("╯", "")
|
||||
.replace("╰", "")
|
||||
.replace("│", "")
|
||||
.replace("─", "")
|
||||
.replace("[", "")
|
||||
.replace("]", "")
|
||||
.replace("bold", "")
|
||||
.replace("blue", "")
|
||||
.replace("yellow", "")
|
||||
.replace("green", "")
|
||||
.replace("red", "")
|
||||
.replace("dim", "")
|
||||
.replace("🤖", "")
|
||||
.replace("🔧", "")
|
||||
.replace("✅", "")
|
||||
.replace("\x1b[93m", "")
|
||||
.replace("\x1b[00m", "")
|
||||
.replace("\\", "")
|
||||
.replace('"', "")
|
||||
.replace("'", "")
|
||||
)
|
||||
|
||||
# More flexible check, look for either the repeated usage message or verification that max iterations was reached
|
||||
output_lower = captured.out.lower()
|
||||
|
||||
has_repeated_usage_message = "tried reusing the same input" in output_lower
|
||||
has_max_iterations = "maximum iterations reached" in output_lower
|
||||
has_final_answer = "final answer" in output_lower or "42" in captured.out
|
||||
# Look for the message in the normalized output, handling the apostrophe difference
|
||||
expected_message = (
|
||||
"I tried reusing the same input, I must stop using this action input"
|
||||
)
|
||||
|
||||
assert (
|
||||
has_repeated_usage_message or (has_max_iterations and has_final_answer)
|
||||
), f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
|
||||
expected_message in output
|
||||
), f"Expected message not found in output. Output was: {output}"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -759,10 +783,10 @@ def test_agent_without_max_rpm_respects_crew_rpm(capsys):
|
||||
|
||||
with patch.object(RPMController, "_wait_for_next_minute") as moveon:
|
||||
moveon.return_value = True
|
||||
result = crew.kickoff()
|
||||
# Verify the crew executed and RPM limit was triggered
|
||||
assert result is not None
|
||||
assert moveon.called
|
||||
crew.kickoff()
|
||||
captured = capsys.readouterr()
|
||||
assert "get_final_answer" in captured.out
|
||||
assert "Max RPM reached, waiting for next minute to start." in captured.out
|
||||
moveon.assert_called_once()
|
||||
|
||||
|
||||
@@ -1189,13 +1213,17 @@ Thought:<|eot_id|>
|
||||
def test_task_allow_crewai_trigger_context():
|
||||
from crewai import Crew
|
||||
|
||||
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory"
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Analyze the data",
|
||||
expected_output="Analysis report",
|
||||
agent=agent,
|
||||
allow_crewai_trigger_context=True,
|
||||
allow_crewai_trigger_context=True
|
||||
)
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff({"crewai_trigger_payload": "Important context data"})
|
||||
@@ -1210,13 +1238,17 @@ def test_task_allow_crewai_trigger_context():
|
||||
def test_task_without_allow_crewai_trigger_context():
|
||||
from crewai import Crew
|
||||
|
||||
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory"
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Analyze the data",
|
||||
expected_output="Analysis report",
|
||||
agent=agent,
|
||||
allow_crewai_trigger_context=False,
|
||||
allow_crewai_trigger_context=False
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
@@ -1233,18 +1265,23 @@ def test_task_without_allow_crewai_trigger_context():
|
||||
def test_task_allow_crewai_trigger_context_no_payload():
|
||||
from crewai import Crew
|
||||
|
||||
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory"
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Analyze the data",
|
||||
expected_output="Analysis report",
|
||||
agent=agent,
|
||||
allow_crewai_trigger_context=True,
|
||||
allow_crewai_trigger_context=True
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff({"other_input": "other data"})
|
||||
|
||||
|
||||
prompt = task.prompt()
|
||||
|
||||
assert "Analyze the data" in prompt
|
||||
@@ -1256,9 +1293,7 @@ def test_do_not_allow_crewai_trigger_context_for_first_task_hierarchical():
|
||||
from crewai import Crew
|
||||
|
||||
agent1 = Agent(role="First Agent", goal="First goal", backstory="First backstory")
|
||||
agent2 = Agent(
|
||||
role="Second Agent", goal="Second goal", backstory="Second backstory"
|
||||
)
|
||||
agent2 = Agent(role="Second Agent", goal="Second goal", backstory="Second backstory")
|
||||
|
||||
first_task = Task(
|
||||
description="Process initial data",
|
||||
@@ -1266,11 +1301,12 @@ def test_do_not_allow_crewai_trigger_context_for_first_task_hierarchical():
|
||||
agent=agent1,
|
||||
)
|
||||
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent1, agent2],
|
||||
tasks=[first_task],
|
||||
process=Process.hierarchical,
|
||||
manager_llm="gpt-4o",
|
||||
manager_llm="gpt-4o"
|
||||
)
|
||||
|
||||
crew.kickoff({"crewai_trigger_payload": "Initial context data"})
|
||||
@@ -1285,9 +1321,7 @@ def test_first_task_auto_inject_trigger():
|
||||
from crewai import Crew
|
||||
|
||||
agent1 = Agent(role="First Agent", goal="First goal", backstory="First backstory")
|
||||
agent2 = Agent(
|
||||
role="Second Agent", goal="Second goal", backstory="Second backstory"
|
||||
)
|
||||
agent2 = Agent(role="Second Agent", goal="Second goal", backstory="Second backstory")
|
||||
|
||||
first_task = Task(
|
||||
description="Process initial data",
|
||||
@@ -1301,7 +1335,10 @@ def test_first_task_auto_inject_trigger():
|
||||
agent=agent2,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent1, agent2], tasks=[first_task, second_task])
|
||||
crew = Crew(
|
||||
agents=[agent1, agent2],
|
||||
tasks=[first_task, second_task]
|
||||
)
|
||||
crew.kickoff({"crewai_trigger_payload": "Initial context data"})
|
||||
|
||||
first_prompt = first_task.prompt()
|
||||
@@ -1312,31 +1349,31 @@ def test_first_task_auto_inject_trigger():
|
||||
assert "Process secondary data" in second_prompt
|
||||
assert "Trigger Payload:" not in second_prompt
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_ensure_first_task_allow_crewai_trigger_context_is_false_does_not_inject():
|
||||
from crewai import Crew
|
||||
|
||||
agent1 = Agent(role="First Agent", goal="First goal", backstory="First backstory")
|
||||
agent2 = Agent(
|
||||
role="Second Agent", goal="Second goal", backstory="Second backstory"
|
||||
)
|
||||
agent2 = Agent(role="Second Agent", goal="Second goal", backstory="Second backstory")
|
||||
|
||||
first_task = Task(
|
||||
description="Process initial data",
|
||||
expected_output="Initial analysis",
|
||||
agent=agent1,
|
||||
allow_crewai_trigger_context=False,
|
||||
allow_crewai_trigger_context=False
|
||||
)
|
||||
|
||||
second_task = Task(
|
||||
description="Process secondary data",
|
||||
expected_output="Secondary analysis",
|
||||
agent=agent2,
|
||||
allow_crewai_trigger_context=True,
|
||||
allow_crewai_trigger_context=True
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent1, agent2], tasks=[first_task, second_task])
|
||||
crew = Crew(
|
||||
agents=[agent1, agent2],
|
||||
tasks=[first_task, second_task]
|
||||
)
|
||||
crew.kickoff({"crewai_trigger_payload": "Context data"})
|
||||
|
||||
first_prompt = first_task.prompt()
|
||||
@@ -1346,6 +1383,7 @@ def test_ensure_first_task_allow_crewai_trigger_context_is_false_does_not_inject
|
||||
assert "Trigger Payload: Context data" in second_prompt
|
||||
|
||||
|
||||
|
||||
@patch("crewai.agent.CrewTrainingHandler")
|
||||
def test_agent_training_handler(crew_training_handler):
|
||||
task_prompt = "What is 1 + 1?"
|
||||
@@ -2309,13 +2347,12 @@ def mock_get_auth_token():
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||
def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
# Mock embedchain initialization to prevent race conditions in parallel CI execution
|
||||
with patch("embedchain.client.Client.setup"):
|
||||
from crewai_tools import (
|
||||
SerperDevTool,
|
||||
FileReadTool,
|
||||
EnterpriseActionTool,
|
||||
)
|
||||
from crewai_tools import (
|
||||
SerperDevTool,
|
||||
XMLSearchTool,
|
||||
CSVSearchTool,
|
||||
EnterpriseActionTool,
|
||||
)
|
||||
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
@@ -2331,9 +2368,10 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
},
|
||||
{
|
||||
"module": "crewai_tools",
|
||||
"name": "FileReadTool",
|
||||
"init_params": {"file_path": "test.txt"},
|
||||
"name": "XMLSearchTool",
|
||||
"init_params": {"summarize": "true"},
|
||||
},
|
||||
{"module": "crewai_tools", "name": "CSVSearchTool", "init_params": {}},
|
||||
# using a tools that returns a list of BaseTools
|
||||
{
|
||||
"module": "crewai_tools",
|
||||
@@ -2358,22 +2396,23 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
assert agent.role == "test role"
|
||||
assert agent.goal == "test goal"
|
||||
assert agent.backstory == "test backstory"
|
||||
assert len(agent.tools) == 3
|
||||
assert len(agent.tools) == 4
|
||||
|
||||
assert isinstance(agent.tools[0], SerperDevTool)
|
||||
assert agent.tools[0].n_results == 30
|
||||
assert isinstance(agent.tools[1], FileReadTool)
|
||||
assert agent.tools[1].file_path == "test.txt"
|
||||
assert isinstance(agent.tools[1], XMLSearchTool)
|
||||
assert agent.tools[1].summarize
|
||||
|
||||
assert isinstance(agent.tools[2], EnterpriseActionTool)
|
||||
assert agent.tools[2].name == "test_name"
|
||||
assert isinstance(agent.tools[2], CSVSearchTool)
|
||||
assert not agent.tools[2].summarize
|
||||
|
||||
assert isinstance(agent.tools[3], EnterpriseActionTool)
|
||||
assert agent.tools[3].name == "test_name"
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||
def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth_token):
|
||||
# Mock embedchain initialization to prevent race conditions in parallel CI execution
|
||||
with patch("embedchain.client.Client.setup"):
|
||||
from crewai_tools import SerperDevTool
|
||||
from crewai_tools import SerperDevTool
|
||||
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
@@ -108,9 +108,7 @@ class TestValidateToken(unittest.TestCase):
|
||||
|
||||
|
||||
class TestTokenManager(unittest.TestCase):
|
||||
@patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key")
|
||||
def setUp(self, mock_get_key):
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
def setUp(self):
|
||||
self.token_manager = TokenManager()
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
|
||||
|
||||
@@ -4,7 +4,6 @@ import unittest
|
||||
import unittest.mock
|
||||
from datetime import datetime, timedelta
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -28,18 +27,12 @@ def in_temp_dir():
|
||||
|
||||
@pytest.fixture
|
||||
def tool_command():
|
||||
# Create a temporary directory for each test to avoid token storage conflicts
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Mock the secure storage path to use the temp directory
|
||||
with patch.object(
|
||||
TokenManager, "get_secure_storage_path", return_value=Path(temp_dir)
|
||||
):
|
||||
TokenManager().save_tokens(
|
||||
"test-token", (datetime.now() + timedelta(seconds=36000)).timestamp()
|
||||
)
|
||||
tool_command = ToolCommand()
|
||||
with patch.object(tool_command, "login"):
|
||||
yield tool_command
|
||||
TokenManager().save_tokens(
|
||||
"test-token", (datetime.now() + timedelta(seconds=36000)).timestamp()
|
||||
)
|
||||
tool_command = ToolCommand()
|
||||
with patch.object(tool_command, "login"):
|
||||
yield tool_command
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Test Agent creation and execution basic functionality."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from concurrent.futures import Future
|
||||
@@ -27,6 +26,7 @@ from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.events import (
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
@@ -36,6 +36,7 @@ from crewai.utilities.events.crew_events import (
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestStartedEvent,
|
||||
)
|
||||
from crewai.utilities.events.event_listener import EventListener
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
|
||||
|
||||
@@ -51,7 +52,6 @@ from crewai.utilities.events.memory_events import (
|
||||
)
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ceo():
|
||||
return Agent(
|
||||
@@ -311,6 +311,7 @@ def test_crew_creation(researcher, writer):
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_sync_task_execution(researcher, writer):
|
||||
|
||||
tasks = [
|
||||
Task(
|
||||
description="Give me a list of 5 interesting ideas to explore for an article, what makes them unique and interesting.",
|
||||
@@ -849,7 +850,6 @@ def test_crew_verbose_output(researcher, writer, capsys):
|
||||
),
|
||||
]
|
||||
|
||||
# Test with verbose=True
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=tasks,
|
||||
@@ -857,25 +857,46 @@ def test_crew_verbose_output(researcher, writer, capsys):
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
crew.kickoff()
|
||||
captured = capsys.readouterr()
|
||||
|
||||
# Verify the crew executed successfully and verbose was set
|
||||
assert result is not None
|
||||
assert crew.verbose is True
|
||||
|
||||
# Test with verbose=False
|
||||
crew_quiet = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
verbose=False,
|
||||
# Filter out event listener logs (lines starting with '[')
|
||||
filtered_output = "\n".join(
|
||||
line for line in captured.out.split("\n") if not line.startswith("[")
|
||||
)
|
||||
|
||||
result_quiet = crew_quiet.kickoff()
|
||||
expected_strings = [
|
||||
"🤖 Agent Started",
|
||||
"Agent: Researcher",
|
||||
"Task: Research AI advancements.",
|
||||
"✅ Agent Final Answer",
|
||||
"Agent: Researcher",
|
||||
"🤖 Agent Started",
|
||||
"Agent: Senior Writer",
|
||||
"Task: Write about AI in healthcare.",
|
||||
"✅ Agent Final Answer",
|
||||
"Agent: Senior Writer",
|
||||
]
|
||||
|
||||
# Verify the crew executed successfully and verbose was not set
|
||||
assert result_quiet is not None
|
||||
assert crew_quiet.verbose is False
|
||||
for expected_string in expected_strings:
|
||||
assert (
|
||||
expected_string in filtered_output
|
||||
), f"Expected '{expected_string}' in output, but it was not found."
|
||||
|
||||
# Now test with verbose set to False
|
||||
crew.verbose = False
|
||||
crew._logger = Logger(verbose=False)
|
||||
event_listener = EventListener()
|
||||
event_listener.verbose = False
|
||||
event_listener.formatter.verbose = False
|
||||
crew.kickoff()
|
||||
captured = capsys.readouterr()
|
||||
filtered_output = "\n".join(
|
||||
line
|
||||
for line in captured.out.split("\n")
|
||||
if not line.startswith("[") and line.strip() and not line.startswith("\x1b")
|
||||
)
|
||||
assert filtered_output == ""
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -938,6 +959,7 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_api_calls_throttling(capsys):
|
||||
|
||||
from crewai.tools import tool
|
||||
|
||||
@tool
|
||||
@@ -1513,6 +1535,7 @@ async def test_async_kickoff_for_each_async_empty_input():
|
||||
|
||||
|
||||
def test_set_agents_step_callback():
|
||||
|
||||
researcher_agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Make the best research and analysis on content about AI and AI agents",
|
||||
@@ -1541,6 +1564,7 @@ def test_set_agents_step_callback():
|
||||
|
||||
|
||||
def test_dont_set_agents_step_callback_if_already_set():
|
||||
|
||||
def agent_callback(_):
|
||||
pass
|
||||
|
||||
@@ -1638,47 +1662,42 @@ def test_task_with_no_arguments():
|
||||
|
||||
|
||||
def test_code_execution_flag_adds_code_tool_upon_kickoff():
|
||||
try:
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
except (ImportError, Exception):
|
||||
pytest.skip("crewai_tools not available or cannot be imported")
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
|
||||
# Mock Docker validation for the entire test
|
||||
with patch.object(Agent, "_validate_docker_installation"):
|
||||
programmer = Agent(
|
||||
role="Programmer",
|
||||
goal="Write code to solve problems.",
|
||||
backstory="You're a programmer who loves to solve problems with code.",
|
||||
allow_delegation=False,
|
||||
allow_code_execution=True,
|
||||
)
|
||||
programmer = Agent(
|
||||
role="Programmer",
|
||||
goal="Write code to solve problems.",
|
||||
backstory="You're a programmer who loves to solve problems with code.",
|
||||
allow_delegation=False,
|
||||
allow_code_execution=True,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="How much is 2 + 2?",
|
||||
expected_output="The result of the sum as an integer.",
|
||||
agent=programmer,
|
||||
)
|
||||
task = Task(
|
||||
description="How much is 2 + 2?",
|
||||
expected_output="The result of the sum as an integer.",
|
||||
agent=programmer,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[programmer], tasks=[task])
|
||||
crew = Crew(agents=[programmer], tasks=[task])
|
||||
|
||||
mock_task_output = TaskOutput(
|
||||
description="Mock description", raw="mocked output", agent="mocked agent"
|
||||
)
|
||||
mock_task_output = TaskOutput(
|
||||
description="Mock description", raw="mocked output", agent="mocked agent"
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
Task, "execute_sync", return_value=mock_task_output
|
||||
) as mock_execute_sync:
|
||||
crew.kickoff()
|
||||
with patch.object(
|
||||
Task, "execute_sync", return_value=mock_task_output
|
||||
) as mock_execute_sync:
|
||||
crew.kickoff()
|
||||
|
||||
# Get the tools that were actually used in execution
|
||||
_, kwargs = mock_execute_sync.call_args
|
||||
used_tools = kwargs["tools"]
|
||||
# Get the tools that were actually used in execution
|
||||
_, kwargs = mock_execute_sync.call_args
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Verify that exactly one tool was used and it was a CodeInterpreterTool
|
||||
assert len(used_tools) == 1, "Should have exactly one tool"
|
||||
assert isinstance(
|
||||
used_tools[0], CodeInterpreterTool
|
||||
), "Tool should be CodeInterpreterTool"
|
||||
# Verify that exactly one tool was used and it was a CodeInterpreterTool
|
||||
assert len(used_tools) == 1, "Should have exactly one tool"
|
||||
assert isinstance(
|
||||
used_tools[0], CodeInterpreterTool
|
||||
), "Tool should be CodeInterpreterTool"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -2009,6 +2028,7 @@ def test_crew_inputs_interpolate_both_agents_and_tasks():
|
||||
|
||||
|
||||
def test_crew_inputs_interpolate_both_agents_and_tasks_diff():
|
||||
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
goal="Express hot takes on {topic}.",
|
||||
@@ -2040,6 +2060,7 @@ def test_crew_inputs_interpolate_both_agents_and_tasks_diff():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_does_not_interpolate_without_inputs():
|
||||
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
goal="Express hot takes on {topic}.",
|
||||
@@ -2173,6 +2194,7 @@ def test_task_same_callback_both_on_task_and_crew():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_tools_with_custom_caching():
|
||||
|
||||
from crewai.tools import tool
|
||||
|
||||
@tool
|
||||
@@ -2452,6 +2474,7 @@ def test_multiple_conditional_tasks(researcher, writer):
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_using_contextual_memory():
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -2549,6 +2572,7 @@ def test_memory_events_are_emitted():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_using_contextual_memory_with_long_term_memory():
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -2578,6 +2602,7 @@ def test_using_contextual_memory_with_long_term_memory():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_warning_long_term_memory_without_entity_memory():
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -2613,6 +2638,7 @@ def test_warning_long_term_memory_without_entity_memory():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_long_term_memory_with_memory_flag():
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -2646,6 +2672,7 @@ def test_long_term_memory_with_memory_flag():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_using_contextual_memory_with_short_term_memory():
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -2675,6 +2702,7 @@ def test_using_contextual_memory_with_short_term_memory():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_disabled_memory_using_contextual_memory():
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -2801,6 +2829,7 @@ def test_crew_output_file_validation_failures():
|
||||
|
||||
|
||||
def test_manager_agent(researcher, writer):
|
||||
|
||||
task = Task(
|
||||
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
|
||||
expected_output="5 bullet points with a paragraph for each idea.",
|
||||
@@ -3828,9 +3857,7 @@ def test_task_tools_preserve_code_execution_tools():
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
# Mock embedchain initialization to prevent race conditions in parallel CI execution
|
||||
with patch("embedchain.client.Client.setup"):
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -4432,6 +4459,7 @@ def test_crew_copy_with_memory():
|
||||
original_entity_id = id(crew._entity_memory) if crew._entity_memory else None
|
||||
original_external_id = id(crew._external_memory) if crew._external_memory else None
|
||||
|
||||
|
||||
try:
|
||||
crew_copy = crew.copy()
|
||||
|
||||
@@ -4481,6 +4509,7 @@ def test_crew_copy_with_memory():
|
||||
or crew_copy._external_memory is None
|
||||
), "Copied _external_memory should be None if not originally present"
|
||||
|
||||
|
||||
except pydantic_core.ValidationError as e:
|
||||
if "Input should be an instance of" in str(e) and ("Memory" in str(e)):
|
||||
pytest.fail(
|
||||
@@ -4697,7 +4726,6 @@ def test_reset_agent_knowledge_with_only_agent_knowledge(researcher, writer):
|
||||
[mock_ks_research, mock_ks_writer]
|
||||
)
|
||||
|
||||
|
||||
def test_default_crew_name(researcher, writer):
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
@@ -4738,18 +4766,9 @@ def test_ensure_exchanged_messages_are_propagated_to_external_memory():
|
||||
crew.kickoff()
|
||||
|
||||
expected_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are Researcher. You're an expert in research and you love to learn new things.\nYour personal goal is: You research about math.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "\nCurrent Task: Research a topic to teach a kid aged 6 about math.\n\nThis is the expected criteria for your final answer: A topic, explanation, angle, and examples.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I now can give a great answer \nFinal Answer: \n\n**Topic: Understanding Shapes (Geometry)**\n\n**Explanation:** \nShapes are everywhere around us! They are the special forms that we can see in everyday objects. Teaching a 6-year-old about shapes is not only fun but also a way to help them think about the world around them and develop their spatial awareness. We will focus on basic shapes: circle, square, triangle, and rectangle. Understanding these shapes helps kids recognize and describe their environment.\n\n**Angle:** \nLet’s make learning about shapes an adventure! We can turn it into a treasure hunt where the child has to find objects around the house or outside that match the shapes we learn. This hands-on approach helps make the learning stick!\n\n**Examples:** \n1. **Circle:** \n - Explanation: A circle is round and has no corners. It looks like a wheel or a cookie! \n - Activity: Find objects that are circles, such as a clock, a dinner plate, or a ball. Draw a big circle on a paper and then try to draw smaller circles inside it.\n\n2. **Square:** \n - Explanation: A square has four equal sides and four corners. It looks like a box! \n - Activity: Look for squares in books, in windows, or in building blocks. Try to build a tall tower using square blocks!\n\n3. **Triangle:** \n - Explanation: A triangle has three sides and three corners. It looks like a slice of pizza or a roof! \n - Activity: Use crayons to draw a big triangle and then find things that are shaped like a triangle, like a slice of cheese or a traffic sign.\n\n4. **Rectangle:** \n - Explanation: A rectangle has four sides but only opposite sides are equal. It’s like a stretched square! \n - Activity: Search for rectangles, such as a book cover or a door. You can cut out rectangles from colored paper and create a collage!\n\nBy relating the shapes to fun activities and using real-world examples, we not only make learning more enjoyable but also help the child better remember and understand the concept of shapes in math. This foundation forms the basis of their future learning in geometry!",
|
||||
},
|
||||
{'role': 'system', 'content': "You are Researcher. You're an expert in research and you love to learn new things.\nYour personal goal is: You research about math.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!"},
|
||||
{'role': 'user', 'content': '\nCurrent Task: Research a topic to teach a kid aged 6 about math.\n\nThis is the expected criteria for your final answer: A topic, explanation, angle, and examples.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'},
|
||||
{'role': 'assistant', 'content': 'I now can give a great answer \nFinal Answer: \n\n**Topic: Understanding Shapes (Geometry)**\n\n**Explanation:** \nShapes are everywhere around us! They are the special forms that we can see in everyday objects. Teaching a 6-year-old about shapes is not only fun but also a way to help them think about the world around them and develop their spatial awareness. We will focus on basic shapes: circle, square, triangle, and rectangle. Understanding these shapes helps kids recognize and describe their environment.\n\n**Angle:** \nLet’s make learning about shapes an adventure! We can turn it into a treasure hunt where the child has to find objects around the house or outside that match the shapes we learn. This hands-on approach helps make the learning stick!\n\n**Examples:** \n1. **Circle:** \n - Explanation: A circle is round and has no corners. It looks like a wheel or a cookie! \n - Activity: Find objects that are circles, such as a clock, a dinner plate, or a ball. Draw a big circle on a paper and then try to draw smaller circles inside it.\n\n2. **Square:** \n - Explanation: A square has four equal sides and four corners. It looks like a box! \n - Activity: Look for squares in books, in windows, or in building blocks. Try to build a tall tower using square blocks!\n\n3. **Triangle:** \n - Explanation: A triangle has three sides and three corners. It looks like a slice of pizza or a roof! \n - Activity: Use crayons to draw a big triangle and then find things that are shaped like a triangle, like a slice of cheese or a traffic sign.\n\n4. **Rectangle:** \n - Explanation: A rectangle has four sides but only opposite sides are equal. It’s like a stretched square! \n - Activity: Search for rectangles, such as a book cover or a door. You can cut out rectangles from colored paper and create a collage!\n\nBy relating the shapes to fun activities and using real-world examples, we not only make learning more enjoyable but also help the child better remember and understand the concept of shapes in math. This foundation forms the basis of their future learning in geometry!'}
|
||||
]
|
||||
external_memory_save.assert_called_once_with(
|
||||
value=ANY,
|
||||
@@ -3,7 +3,6 @@ from unittest.mock import MagicMock
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class BaseEvaluationMetricsTest:
|
||||
@pytest.fixture
|
||||
def mock_agent(self):
|
||||
@@ -25,5 +24,5 @@ class BaseEvaluationMetricsTest:
|
||||
def execution_trace(self):
|
||||
return {
|
||||
"thinking": ["I need to analyze this data carefully"],
|
||||
"actions": ["Gathered information", "Analyzed data"],
|
||||
}
|
||||
"actions": ["Gathered information", "Analyzed data"]
|
||||
}
|
||||
@@ -1,7 +1,5 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
from tests.experimental.evaluation.metrics.test_base_evaluation_metrics import (
|
||||
BaseEvaluationMetricsTest,
|
||||
)
|
||||
from tests.experimental.evaluation.metrics.base_evaluation_metrics_test import BaseEvaluationMetricsTest
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import EvaluationScore
|
||||
from crewai.experimental.evaluation.metrics.goal_metrics import GoalAlignmentEvaluator
|
||||
@@ -10,9 +8,7 @@ from crewai.utilities.llm_utils import LLM
|
||||
|
||||
class TestGoalAlignmentEvaluator(BaseEvaluationMetricsTest):
|
||||
@patch("crewai.utilities.llm_utils.create_llm")
|
||||
def test_evaluate_success(
|
||||
self, mock_create_llm, mock_agent, mock_task, execution_trace
|
||||
):
|
||||
def test_evaluate_success(self, mock_create_llm, mock_agent, mock_task, execution_trace):
|
||||
mock_llm = MagicMock(spec=LLM)
|
||||
mock_llm.call.return_value = """
|
||||
{
|
||||
@@ -28,7 +24,7 @@ class TestGoalAlignmentEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="This is the final output",
|
||||
final_output="This is the final output"
|
||||
)
|
||||
|
||||
assert isinstance(result, EvaluationScore)
|
||||
@@ -44,9 +40,7 @@ class TestGoalAlignmentEvaluator(BaseEvaluationMetricsTest):
|
||||
assert mock_task.description in prompt[1]["content"]
|
||||
|
||||
@patch("crewai.utilities.llm_utils.create_llm")
|
||||
def test_evaluate_error_handling(
|
||||
self, mock_create_llm, mock_agent, mock_task, execution_trace
|
||||
):
|
||||
def test_evaluate_error_handling(self, mock_create_llm, mock_agent, mock_task, execution_trace):
|
||||
mock_llm = MagicMock(spec=LLM)
|
||||
mock_llm.call.return_value = "Invalid JSON response"
|
||||
mock_create_llm.return_value = mock_llm
|
||||
@@ -57,7 +51,7 @@ class TestGoalAlignmentEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="This is the final output",
|
||||
final_output="This is the final output"
|
||||
)
|
||||
|
||||
assert isinstance(result, EvaluationScore)
|
||||
|
||||
@@ -6,13 +6,10 @@ from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.experimental.evaluation.metrics.reasoning_metrics import (
|
||||
ReasoningEfficiencyEvaluator,
|
||||
)
|
||||
from tests.experimental.evaluation.metrics.test_base_evaluation_metrics import (
|
||||
BaseEvaluationMetricsTest,
|
||||
)
|
||||
from tests.experimental.evaluation.metrics.base_evaluation_metrics_test import BaseEvaluationMetricsTest
|
||||
from crewai.utilities.llm_utils import LLM
|
||||
from crewai.experimental.evaluation.base_evaluator import EvaluationScore
|
||||
|
||||
|
||||
class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
@pytest.fixture
|
||||
def mock_output(self):
|
||||
@@ -26,18 +23,18 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
{
|
||||
"prompt": "How should I approach this task?",
|
||||
"response": "I'll first research the topic, then compile findings.",
|
||||
"timestamp": 1626987654,
|
||||
"timestamp": 1626987654
|
||||
},
|
||||
{
|
||||
"prompt": "What resources should I use?",
|
||||
"response": "I'll use relevant academic papers and reliable websites.",
|
||||
"timestamp": 1626987754,
|
||||
"timestamp": 1626987754
|
||||
},
|
||||
{
|
||||
"prompt": "How should I structure the output?",
|
||||
"response": "I'll organize information clearly with headings and bullet points.",
|
||||
"timestamp": 1626987854,
|
||||
},
|
||||
"timestamp": 1626987854
|
||||
}
|
||||
]
|
||||
|
||||
def test_insufficient_llm_calls(self, mock_agent, mock_task, mock_output):
|
||||
@@ -48,7 +45,7 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output=mock_output,
|
||||
final_output=mock_output
|
||||
)
|
||||
|
||||
assert isinstance(result, EvaluationScore)
|
||||
@@ -56,9 +53,7 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
assert "Insufficient LLM calls" in result.feedback
|
||||
|
||||
@patch("crewai.utilities.llm_utils.create_llm")
|
||||
def test_successful_evaluation(
|
||||
self, mock_create_llm, mock_agent, mock_task, mock_output, llm_calls
|
||||
):
|
||||
def test_successful_evaluation(self, mock_create_llm, mock_agent, mock_task, mock_output, llm_calls):
|
||||
mock_llm = MagicMock(spec=LLM)
|
||||
mock_llm.call.return_value = """
|
||||
{
|
||||
@@ -88,7 +83,7 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output=mock_output,
|
||||
final_output=mock_output
|
||||
)
|
||||
|
||||
# Assertions
|
||||
@@ -102,9 +97,7 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
mock_llm.call.assert_called_once()
|
||||
|
||||
@patch("crewai.utilities.llm_utils.create_llm")
|
||||
def test_parse_error_handling(
|
||||
self, mock_create_llm, mock_agent, mock_task, mock_output, llm_calls
|
||||
):
|
||||
def test_parse_error_handling(self, mock_create_llm, mock_agent, mock_task, mock_output, llm_calls):
|
||||
mock_llm = MagicMock(spec=LLM)
|
||||
mock_llm.call.return_value = "Invalid JSON response"
|
||||
mock_create_llm.return_value = mock_llm
|
||||
@@ -121,7 +114,7 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output=mock_output,
|
||||
final_output=mock_output
|
||||
)
|
||||
|
||||
# Assertions for error handling
|
||||
@@ -133,31 +126,11 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
def test_loop_detection(self, mock_create_llm, mock_agent, mock_task, mock_output):
|
||||
# Setup LLM calls with a repeating pattern
|
||||
repetitive_llm_calls = [
|
||||
{
|
||||
"prompt": "How to solve?",
|
||||
"response": "I'll try method A",
|
||||
"timestamp": 1000,
|
||||
},
|
||||
{
|
||||
"prompt": "Let me try method A",
|
||||
"response": "It didn't work",
|
||||
"timestamp": 1100,
|
||||
},
|
||||
{
|
||||
"prompt": "How to solve?",
|
||||
"response": "I'll try method A again",
|
||||
"timestamp": 1200,
|
||||
},
|
||||
{
|
||||
"prompt": "Let me try method A",
|
||||
"response": "It didn't work",
|
||||
"timestamp": 1300,
|
||||
},
|
||||
{
|
||||
"prompt": "How to solve?",
|
||||
"response": "I'll try method A one more time",
|
||||
"timestamp": 1400,
|
||||
},
|
||||
{"prompt": "How to solve?", "response": "I'll try method A", "timestamp": 1000},
|
||||
{"prompt": "Let me try method A", "response": "It didn't work", "timestamp": 1100},
|
||||
{"prompt": "How to solve?", "response": "I'll try method A again", "timestamp": 1200},
|
||||
{"prompt": "Let me try method A", "response": "It didn't work", "timestamp": 1300},
|
||||
{"prompt": "How to solve?", "response": "I'll try method A one more time", "timestamp": 1400}
|
||||
]
|
||||
|
||||
mock_llm = MagicMock(spec=LLM)
|
||||
@@ -185,7 +158,7 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output=mock_output,
|
||||
final_output=mock_output
|
||||
)
|
||||
|
||||
assert isinstance(result, EvaluationScore)
|
||||
|
||||
@@ -1,20 +1,13 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import EvaluationScore
|
||||
from crewai.experimental.evaluation.metrics.semantic_quality_metrics import (
|
||||
SemanticQualityEvaluator,
|
||||
)
|
||||
from tests.experimental.evaluation.metrics.test_base_evaluation_metrics import (
|
||||
BaseEvaluationMetricsTest,
|
||||
)
|
||||
from crewai.experimental.evaluation.metrics.semantic_quality_metrics import SemanticQualityEvaluator
|
||||
from tests.experimental.evaluation.metrics.base_evaluation_metrics_test import BaseEvaluationMetricsTest
|
||||
from crewai.utilities.llm_utils import LLM
|
||||
|
||||
|
||||
class TestSemanticQualityEvaluator(BaseEvaluationMetricsTest):
|
||||
@patch("crewai.utilities.llm_utils.create_llm")
|
||||
def test_evaluate_success(
|
||||
self, mock_create_llm, mock_agent, mock_task, execution_trace
|
||||
):
|
||||
def test_evaluate_success(self, mock_create_llm, mock_agent, mock_task, execution_trace):
|
||||
mock_llm = MagicMock(spec=LLM)
|
||||
mock_llm.call.return_value = """
|
||||
{
|
||||
@@ -30,7 +23,7 @@ class TestSemanticQualityEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="This is a well-structured analysis of the data.",
|
||||
final_output="This is a well-structured analysis of the data."
|
||||
)
|
||||
|
||||
assert isinstance(result, EvaluationScore)
|
||||
@@ -46,9 +39,7 @@ class TestSemanticQualityEvaluator(BaseEvaluationMetricsTest):
|
||||
assert mock_task.description in prompt[1]["content"]
|
||||
|
||||
@patch("crewai.utilities.llm_utils.create_llm")
|
||||
def test_evaluate_with_empty_output(
|
||||
self, mock_create_llm, mock_agent, mock_task, execution_trace
|
||||
):
|
||||
def test_evaluate_with_empty_output(self, mock_create_llm, mock_agent, mock_task, execution_trace):
|
||||
mock_llm = MagicMock(spec=LLM)
|
||||
mock_llm.call.return_value = """
|
||||
{
|
||||
@@ -64,7 +55,7 @@ class TestSemanticQualityEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="",
|
||||
final_output=""
|
||||
)
|
||||
|
||||
assert isinstance(result, EvaluationScore)
|
||||
@@ -72,9 +63,7 @@ class TestSemanticQualityEvaluator(BaseEvaluationMetricsTest):
|
||||
assert "empty or minimal" in result.feedback
|
||||
|
||||
@patch("crewai.utilities.llm_utils.create_llm")
|
||||
def test_evaluate_error_handling(
|
||||
self, mock_create_llm, mock_agent, mock_task, execution_trace
|
||||
):
|
||||
def test_evaluate_error_handling(self, mock_create_llm, mock_agent, mock_task, execution_trace):
|
||||
mock_llm = MagicMock(spec=LLM)
|
||||
mock_llm.call.return_value = "Invalid JSON response"
|
||||
mock_create_llm.return_value = mock_llm
|
||||
@@ -85,9 +74,9 @@ class TestSemanticQualityEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="This is the output.",
|
||||
final_output="This is the output."
|
||||
)
|
||||
|
||||
assert isinstance(result, EvaluationScore)
|
||||
assert result.score is None
|
||||
assert "Failed to parse" in result.feedback
|
||||
assert "Failed to parse" in result.feedback
|
||||
@@ -3,13 +3,10 @@ from unittest.mock import patch, MagicMock
|
||||
from crewai.experimental.evaluation.metrics.tools_metrics import (
|
||||
ToolSelectionEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ToolInvocationEvaluator
|
||||
)
|
||||
from crewai.utilities.llm_utils import LLM
|
||||
from tests.experimental.evaluation.metrics.test_base_evaluation_metrics import (
|
||||
BaseEvaluationMetricsTest,
|
||||
)
|
||||
|
||||
from tests.experimental.evaluation.metrics.base_evaluation_metrics_test import BaseEvaluationMetricsTest
|
||||
|
||||
class TestToolSelectionEvaluator(BaseEvaluationMetricsTest):
|
||||
def test_no_tools_available(self, mock_task, mock_agent):
|
||||
@@ -23,7 +20,7 @@ class TestToolSelectionEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="Final output",
|
||||
final_output="Final output"
|
||||
)
|
||||
|
||||
assert result.score is None
|
||||
@@ -38,7 +35,7 @@ class TestToolSelectionEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="Final output",
|
||||
final_output="Final output"
|
||||
)
|
||||
|
||||
assert result.score is None
|
||||
@@ -59,12 +56,8 @@ class TestToolSelectionEvaluator(BaseEvaluationMetricsTest):
|
||||
# Setup execution trace with tool uses
|
||||
execution_trace = {
|
||||
"tool_uses": [
|
||||
{
|
||||
"tool": "search_tool",
|
||||
"input": {"query": "test query"},
|
||||
"output": "search results",
|
||||
},
|
||||
{"tool": "calculator", "input": {"expression": "2+2"}, "output": "4"},
|
||||
{"tool": "search_tool", "input": {"query": "test query"}, "output": "search results"},
|
||||
{"tool": "calculator", "input": {"expression": "2+2"}, "output": "4"}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -73,7 +66,7 @@ class TestToolSelectionEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="Final output",
|
||||
final_output="Final output"
|
||||
)
|
||||
|
||||
assert result.score == 8.5
|
||||
@@ -97,7 +90,7 @@ class TestParameterExtractionEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="Final output",
|
||||
final_output="Final output"
|
||||
)
|
||||
|
||||
assert result.score is None
|
||||
@@ -124,14 +117,14 @@ class TestParameterExtractionEvaluator(BaseEvaluationMetricsTest):
|
||||
"tool": "search_tool",
|
||||
"input": {"query": "test query"},
|
||||
"output": "search results",
|
||||
"error": None,
|
||||
"error": None
|
||||
},
|
||||
{
|
||||
"tool": "calculator",
|
||||
"input": {"expression": "2+2"},
|
||||
"output": "4",
|
||||
"error": None,
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -140,7 +133,7 @@ class TestParameterExtractionEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="Final output",
|
||||
final_output="Final output"
|
||||
)
|
||||
|
||||
assert result.score == 9.0
|
||||
@@ -156,7 +149,7 @@ class TestToolInvocationEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="Final output",
|
||||
final_output="Final output"
|
||||
)
|
||||
|
||||
assert result.score is None
|
||||
@@ -178,12 +171,8 @@ class TestToolInvocationEvaluator(BaseEvaluationMetricsTest):
|
||||
# Setup execution trace with tool uses
|
||||
execution_trace = {
|
||||
"tool_uses": [
|
||||
{
|
||||
"tool": "search_tool",
|
||||
"input": {"query": "test query"},
|
||||
"output": "search results",
|
||||
},
|
||||
{"tool": "calculator", "input": {"expression": "2+2"}, "output": "4"},
|
||||
{"tool": "search_tool", "input": {"query": "test query"}, "output": "search results"},
|
||||
{"tool": "calculator", "input": {"expression": "2+2"}, "output": "4"}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -192,7 +181,7 @@ class TestToolInvocationEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="Final output",
|
||||
final_output="Final output"
|
||||
)
|
||||
|
||||
assert result.score == 8.0
|
||||
@@ -218,14 +207,14 @@ class TestToolInvocationEvaluator(BaseEvaluationMetricsTest):
|
||||
"tool": "search_tool",
|
||||
"input": {"query": "test query"},
|
||||
"output": "search results",
|
||||
"error": None,
|
||||
"error": None
|
||||
},
|
||||
{
|
||||
"tool": "calculator",
|
||||
"input": {"expression": "2+"},
|
||||
"output": None,
|
||||
"error": "Invalid expression",
|
||||
},
|
||||
"error": "Invalid expression"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -234,7 +223,7 @@ class TestToolInvocationEvaluator(BaseEvaluationMetricsTest):
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
execution_trace=execution_trace,
|
||||
final_output="Final output",
|
||||
final_output="Final output"
|
||||
)
|
||||
|
||||
assert result.score == 5.5
|
||||
|
||||
@@ -616,9 +616,7 @@ def test_async_flow_with_trigger_payload():
|
||||
flow = AsyncTriggerFlow()
|
||||
|
||||
test_payload = "Async trigger data"
|
||||
result = asyncio.run(
|
||||
flow.kickoff_async(inputs={"crewai_trigger_payload": test_payload})
|
||||
)
|
||||
result = asyncio.run(flow.kickoff_async(inputs={"crewai_trigger_payload": test_payload}))
|
||||
|
||||
assert captured_payload == [test_payload, "async_started"]
|
||||
assert result == "async_finished"
|
||||
@@ -4,12 +4,12 @@
|
||||
def test_task_output_import():
|
||||
"""Test that TaskOutput can be imported from crewai."""
|
||||
from crewai import TaskOutput
|
||||
|
||||
|
||||
assert TaskOutput is not None
|
||||
|
||||
|
||||
|
||||
|
||||
def test_crew_output_import():
|
||||
"""Test that CrewOutput can be imported from crewai."""
|
||||
from crewai import CrewOutput
|
||||
|
||||
|
||||
assert CrewOutput is not None
|
||||
@@ -17,7 +17,6 @@ from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory():
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
@@ -213,7 +212,6 @@ def custom_storage():
|
||||
custom_storage = CustomStorage()
|
||||
return custom_storage
|
||||
|
||||
|
||||
def test_external_memory_custom_storage(custom_storage, crew_with_external_memory):
|
||||
external_memory = ExternalMemory(storage=custom_storage)
|
||||
|
||||
@@ -235,14 +233,12 @@ def test_external_memory_custom_storage(custom_storage, crew_with_external_memor
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_external_memory_search_events(
|
||||
custom_storage, external_memory_with_mocked_config
|
||||
):
|
||||
|
||||
def test_external_memory_search_events(custom_storage, external_memory_with_mocked_config):
|
||||
events = defaultdict(list)
|
||||
|
||||
external_memory_with_mocked_config.storage = custom_storage
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
@@ -262,39 +258,37 @@ def test_external_memory_search_events(
|
||||
assert len(events["MemoryQueryFailedEvent"]) == 0
|
||||
|
||||
assert dict(events["MemoryQueryStartedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_query_started",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "external_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"query": "test value",
|
||||
"limit": 3,
|
||||
"score_threshold": 0.35,
|
||||
'timestamp': ANY,
|
||||
'type': 'memory_query_started',
|
||||
'source_fingerprint': None,
|
||||
'source_type': 'external_memory',
|
||||
'fingerprint_metadata': None,
|
||||
'query': 'test value',
|
||||
'limit': 3,
|
||||
'score_threshold': 0.35
|
||||
}
|
||||
|
||||
assert dict(events["MemoryQueryCompletedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_query_completed",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "external_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"query": "test value",
|
||||
"results": [],
|
||||
"limit": 3,
|
||||
"score_threshold": 0.35,
|
||||
"query_time_ms": ANY,
|
||||
'timestamp': ANY,
|
||||
'type': 'memory_query_completed',
|
||||
'source_fingerprint': None,
|
||||
'source_type': 'external_memory',
|
||||
'fingerprint_metadata': None,
|
||||
'query': 'test value',
|
||||
'results': [],
|
||||
'limit': 3,
|
||||
'score_threshold': 0.35,
|
||||
'query_time_ms': ANY
|
||||
}
|
||||
|
||||
|
||||
def test_external_memory_save_events(
|
||||
custom_storage, external_memory_with_mocked_config
|
||||
):
|
||||
|
||||
def test_external_memory_save_events(custom_storage, external_memory_with_mocked_config):
|
||||
events = defaultdict(list)
|
||||
|
||||
external_memory_with_mocked_config.storage = custom_storage
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
@@ -314,24 +308,24 @@ def test_external_memory_save_events(
|
||||
assert len(events["MemorySaveFailedEvent"]) == 0
|
||||
|
||||
assert dict(events["MemorySaveStartedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_save_started",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "external_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"value": "saving value",
|
||||
"metadata": {"task": "test_task"},
|
||||
"agent_role": "test_agent",
|
||||
'timestamp': ANY,
|
||||
'type': 'memory_save_started',
|
||||
'source_fingerprint': None,
|
||||
'source_type': 'external_memory',
|
||||
'fingerprint_metadata': None,
|
||||
'value': 'saving value',
|
||||
'metadata': {'task': 'test_task'},
|
||||
'agent_role': "test_agent"
|
||||
}
|
||||
|
||||
assert dict(events["MemorySaveCompletedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_save_completed",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "external_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"value": "saving value",
|
||||
"metadata": {"task": "test_task", "agent": "test_agent"},
|
||||
"agent_role": "test_agent",
|
||||
"save_time_ms": ANY,
|
||||
'timestamp': ANY,
|
||||
'type': 'memory_save_completed',
|
||||
'source_fingerprint': None,
|
||||
'source_type': 'external_memory',
|
||||
'fingerprint_metadata': None,
|
||||
'value': 'saving value',
|
||||
'metadata': {'task': 'test_task', 'agent': 'test_agent'},
|
||||
'agent_role': "test_agent",
|
||||
'save_time_ms': ANY
|
||||
}
|
||||
@@ -11,7 +11,6 @@ from crewai.utilities.events.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory():
|
||||
"""Fixture to create a LongTermMemory instance"""
|
||||
@@ -22,7 +21,6 @@ def test_long_term_memory_save_events(long_term_memory):
|
||||
events = defaultdict(list)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
@@ -62,12 +60,7 @@ def test_long_term_memory_save_events(long_term_memory):
|
||||
"source_type": "long_term_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"value": "test_task",
|
||||
"metadata": {
|
||||
"task": "test_task",
|
||||
"quality": 0.5,
|
||||
"agent": "test_agent",
|
||||
"expected_output": "test_output",
|
||||
},
|
||||
"metadata": {"task": "test_task", "quality": 0.5, "agent": "test_agent", "expected_output": "test_output"},
|
||||
"agent_role": "test_agent",
|
||||
"save_time_ms": ANY,
|
||||
}
|
||||
@@ -77,7 +70,6 @@ def test_long_term_memory_search_events(long_term_memory):
|
||||
events = defaultdict(list)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
@@ -88,34 +80,37 @@ def test_long_term_memory_search_events(long_term_memory):
|
||||
|
||||
test_query = "test query"
|
||||
|
||||
long_term_memory.search(test_query, latest_n=5)
|
||||
long_term_memory.search(
|
||||
test_query,
|
||||
latest_n=5
|
||||
)
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 1
|
||||
assert len(events["MemoryQueryFailedEvent"]) == 0
|
||||
|
||||
assert dict(events["MemoryQueryStartedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_query_started",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "long_term_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"query": "test query",
|
||||
"limit": 5,
|
||||
"score_threshold": None,
|
||||
'timestamp': ANY,
|
||||
'type': 'memory_query_started',
|
||||
'source_fingerprint': None,
|
||||
'source_type': 'long_term_memory',
|
||||
'fingerprint_metadata': None,
|
||||
'query': 'test query',
|
||||
'limit': 5,
|
||||
'score_threshold': None
|
||||
}
|
||||
|
||||
assert dict(events["MemoryQueryCompletedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_query_completed",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "long_term_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"query": "test query",
|
||||
"results": None,
|
||||
"limit": 5,
|
||||
"score_threshold": None,
|
||||
"query_time_ms": ANY,
|
||||
'timestamp': ANY,
|
||||
'type': 'memory_query_completed',
|
||||
'source_fingerprint': None,
|
||||
'source_type': 'long_term_memory',
|
||||
'fingerprint_metadata': None,
|
||||
'query': 'test query',
|
||||
'results': None,
|
||||
'limit': 5,
|
||||
'score_threshold': None,
|
||||
'query_time_ms': ANY
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from crewai.project import (
|
||||
from crewai.task import Task
|
||||
from crewai.tools import tool
|
||||
|
||||
|
||||
class SimpleCrew:
|
||||
@agent
|
||||
def simple_agent(self):
|
||||
@@ -86,24 +85,17 @@ class InternalCrew:
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=self.tasks, verbose=True)
|
||||
|
||||
|
||||
@CrewBase
|
||||
class InternalCrewWithMCP(InternalCrew):
|
||||
mcp_server_params = {"host": "localhost", "port": 8000}
|
||||
|
||||
@agent
|
||||
def reporting_analyst(self):
|
||||
return Agent(
|
||||
config=self.agents_config["reporting_analyst"], tools=self.get_mcp_tools()
|
||||
) # type: ignore[index]
|
||||
return Agent(config=self.agents_config["reporting_analyst"], tools=self.get_mcp_tools()) # type: ignore[index]
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(
|
||||
config=self.agents_config["researcher"],
|
||||
tools=self.get_mcp_tools("simple_tool"),
|
||||
) # type: ignore[index]
|
||||
|
||||
return Agent(config=self.agents_config["researcher"], tools=self.get_mcp_tools("simple_tool")) # type: ignore[index]
|
||||
|
||||
def test_agent_memoization():
|
||||
crew = SimpleCrew()
|
||||
@@ -253,18 +245,15 @@ def test_multiple_before_after_kickoff():
|
||||
assert "processed first" in result.raw, "First after_kickoff not executed"
|
||||
assert "processed second" in result.raw, "Second after_kickoff not executed"
|
||||
|
||||
|
||||
def test_crew_name():
|
||||
crew = InternalCrew()
|
||||
assert crew._crew_name == "InternalCrew"
|
||||
|
||||
|
||||
@tool
|
||||
def simple_tool():
|
||||
"""Return 'Hi!'"""
|
||||
return "Hi!"
|
||||
|
||||
|
||||
@tool
|
||||
def another_simple_tool():
|
||||
"""Return 'Hi!'"""
|
||||
@@ -272,11 +261,8 @@ def another_simple_tool():
|
||||
|
||||
|
||||
def test_internal_crew_with_mcp():
|
||||
# Mock embedchain initialization to prevent race conditions in parallel CI execution
|
||||
with patch("embedchain.client.Client.setup"):
|
||||
from crewai_tools import MCPServerAdapter
|
||||
from crewai_tools.adapters.mcp_adapter import ToolCollection
|
||||
|
||||
from crewai_tools import MCPServerAdapter
|
||||
from crewai_tools.adapters.mcp_adapter import ToolCollection
|
||||
mock = Mock(spec=MCPServerAdapter)
|
||||
mock.tools = ToolCollection([simple_tool, another_simple_tool])
|
||||
with patch("crewai_tools.MCPServerAdapter", return_value=mock) as adapter_mock:
|
||||
@@ -284,4 +270,4 @@ def test_internal_crew_with_mcp():
|
||||
assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]
|
||||
assert crew.researcher().tools == [simple_tool]
|
||||
|
||||
adapter_mock.assert_called_once_with({"host": "localhost", "port": 8000})
|
||||
adapter_mock.assert_called_once_with({"host": "localhost", "port": 8000})
|
||||
@@ -1,550 +0,0 @@
|
||||
"""Tests for ChromaDBClient implementation."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
from crewai.rag.types import BaseRecord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chromadb_client():
|
||||
"""Create a mock ChromaDB client."""
|
||||
from chromadb.api import ClientAPI
|
||||
|
||||
return Mock(spec=ClientAPI)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_chromadb_client():
|
||||
"""Create a mock async ChromaDB client."""
|
||||
from chromadb.api import AsyncClientAPI
|
||||
|
||||
return Mock(spec=AsyncClientAPI)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_chromadb_client) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient instance for testing."""
|
||||
client = ChromaDBClient()
|
||||
client.client = mock_chromadb_client
|
||||
client.embedding_function = Mock()
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient instance with async client for testing."""
|
||||
client = ChromaDBClient()
|
||||
client.client = mock_async_chromadb_client
|
||||
client.embedding_function = Mock()
|
||||
return client
|
||||
|
||||
|
||||
class TestChromaDBClient:
|
||||
"""Test suite for ChromaDBClient."""
|
||||
|
||||
def test_create_collection(self, client, mock_chromadb_client):
|
||||
"""Test that create_collection calls the underlying client correctly."""
|
||||
client.create_collection(collection_name="test_collection")
|
||||
|
||||
mock_chromadb_client.create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
configuration=None,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
embedding_function=client.embedding_function,
|
||||
data_loader=None,
|
||||
get_or_create=False,
|
||||
)
|
||||
|
||||
def test_create_collection_with_all_params(self, client, mock_chromadb_client):
|
||||
"""Test create_collection with all optional parameters."""
|
||||
mock_config = Mock()
|
||||
mock_metadata = {"key": "value"}
|
||||
mock_embedding_func = Mock()
|
||||
mock_data_loader = Mock()
|
||||
|
||||
client.create_collection(
|
||||
collection_name="test_collection",
|
||||
configuration=mock_config,
|
||||
metadata=mock_metadata,
|
||||
embedding_function=mock_embedding_func,
|
||||
data_loader=mock_data_loader,
|
||||
get_or_create=True,
|
||||
)
|
||||
|
||||
mock_chromadb_client.create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
configuration=mock_config,
|
||||
metadata=mock_metadata,
|
||||
embedding_function=mock_embedding_func,
|
||||
data_loader=mock_data_loader,
|
||||
get_or_create=True,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acreate_collection(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test that acreate_collection calls the underlying client correctly."""
|
||||
# Make the mock's create_collection an AsyncMock
|
||||
mock_async_chromadb_client.create_collection = AsyncMock(return_value=None)
|
||||
|
||||
await async_client.acreate_collection(collection_name="test_collection")
|
||||
|
||||
mock_async_chromadb_client.create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
configuration=None,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
embedding_function=async_client.embedding_function,
|
||||
data_loader=None,
|
||||
get_or_create=False,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acreate_collection_with_all_params(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test acreate_collection with all optional parameters."""
|
||||
# Make the mock's create_collection an AsyncMock
|
||||
mock_async_chromadb_client.create_collection = AsyncMock(return_value=None)
|
||||
|
||||
mock_config = Mock()
|
||||
mock_metadata = {"key": "value"}
|
||||
mock_embedding_func = Mock()
|
||||
mock_data_loader = Mock()
|
||||
|
||||
await async_client.acreate_collection(
|
||||
collection_name="test_collection",
|
||||
configuration=mock_config,
|
||||
metadata=mock_metadata,
|
||||
embedding_function=mock_embedding_func,
|
||||
data_loader=mock_data_loader,
|
||||
get_or_create=True,
|
||||
)
|
||||
|
||||
mock_async_chromadb_client.create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
configuration=mock_config,
|
||||
metadata=mock_metadata,
|
||||
embedding_function=mock_embedding_func,
|
||||
data_loader=mock_data_loader,
|
||||
get_or_create=True,
|
||||
)
|
||||
|
||||
def test_get_or_create_collection(self, client, mock_chromadb_client):
|
||||
"""Test that get_or_create_collection calls the underlying client correctly."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
result = client.get_or_create_collection(collection_name="test_collection")
|
||||
|
||||
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
configuration=None,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
embedding_function=client.embedding_function,
|
||||
data_loader=None,
|
||||
)
|
||||
assert result == mock_collection
|
||||
|
||||
def test_get_or_create_collection_with_all_params(
|
||||
self, client, mock_chromadb_client
|
||||
):
|
||||
"""Test get_or_create_collection with all optional parameters."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
mock_config = Mock()
|
||||
mock_metadata = {"key": "value"}
|
||||
mock_embedding_func = Mock()
|
||||
mock_data_loader = Mock()
|
||||
|
||||
result = client.get_or_create_collection(
|
||||
collection_name="test_collection",
|
||||
configuration=mock_config,
|
||||
metadata=mock_metadata,
|
||||
embedding_function=mock_embedding_func,
|
||||
data_loader=mock_data_loader,
|
||||
)
|
||||
|
||||
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
configuration=mock_config,
|
||||
metadata=mock_metadata,
|
||||
embedding_function=mock_embedding_func,
|
||||
data_loader=mock_data_loader,
|
||||
)
|
||||
assert result == mock_collection
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_or_create_collection(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test that aget_or_create_collection calls the underlying client correctly."""
|
||||
mock_collection = Mock()
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
result = await async_client.aget_or_create_collection(
|
||||
collection_name="test_collection"
|
||||
)
|
||||
|
||||
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
configuration=None,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
embedding_function=async_client.embedding_function,
|
||||
data_loader=None,
|
||||
)
|
||||
assert result == mock_collection
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_or_create_collection_with_all_params(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test aget_or_create_collection with all optional parameters."""
|
||||
mock_collection = Mock()
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
mock_config = Mock()
|
||||
mock_metadata = {"key": "value"}
|
||||
mock_embedding_func = Mock()
|
||||
mock_data_loader = Mock()
|
||||
|
||||
result = await async_client.aget_or_create_collection(
|
||||
collection_name="test_collection",
|
||||
configuration=mock_config,
|
||||
metadata=mock_metadata,
|
||||
embedding_function=mock_embedding_func,
|
||||
data_loader=mock_data_loader,
|
||||
)
|
||||
|
||||
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
configuration=mock_config,
|
||||
metadata=mock_metadata,
|
||||
embedding_function=mock_embedding_func,
|
||||
data_loader=mock_data_loader,
|
||||
)
|
||||
assert result == mock_collection
|
||||
|
||||
def test_add_documents(self, client, mock_chromadb_client) -> None:
|
||||
"""Test that add_documents adds documents to collection."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
mock_chromadb_client.get_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=client.embedding_function,
|
||||
)
|
||||
|
||||
# Verify documents were added to collection
|
||||
mock_collection.add.assert_called_once()
|
||||
call_args = mock_collection.add.call_args
|
||||
assert len(call_args.kwargs["ids"]) == 1
|
||||
assert call_args.kwargs["documents"] == ["Test document"]
|
||||
assert call_args.kwargs["metadatas"] == [{"source": "test"}]
|
||||
|
||||
def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
|
||||
"""Test add_documents with custom document IDs."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"doc_id": "custom_id_1",
|
||||
"content": "First document",
|
||||
"metadata": {"source": "test1"},
|
||||
},
|
||||
{
|
||||
"doc_id": "custom_id_2",
|
||||
"content": "Second document",
|
||||
"metadata": {"source": "test2"},
|
||||
},
|
||||
]
|
||||
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
mock_collection.add.assert_called_once_with(
|
||||
ids=["custom_id_1", "custom_id_2"],
|
||||
documents=["First document", "Second document"],
|
||||
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
||||
)
|
||||
|
||||
def test_add_documents_empty_list_raises_error(
|
||||
self, client, mock_chromadb_client
|
||||
) -> None:
|
||||
"""Test that add_documents raises error for empty documents list."""
|
||||
with pytest.raises(ValueError, match="Documents list cannot be empty"):
|
||||
client.add_documents(collection_name="test_collection", documents=[])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test that aadd_documents adds documents to collection asynchronously."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
mock_async_chromadb_client.get_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=async_client.embedding_function,
|
||||
)
|
||||
|
||||
# Verify documents were added to collection
|
||||
mock_collection.add.assert_called_once()
|
||||
call_args = mock_collection.add.call_args
|
||||
assert len(call_args.kwargs["ids"]) == 1
|
||||
assert call_args.kwargs["documents"] == ["Test document"]
|
||||
assert call_args.kwargs["metadatas"] == [{"source": "test"}]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_with_custom_ids(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test aadd_documents with custom document IDs."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"doc_id": "custom_id_1",
|
||||
"content": "First document",
|
||||
"metadata": {"source": "test1"},
|
||||
},
|
||||
{
|
||||
"doc_id": "custom_id_2",
|
||||
"content": "Second document",
|
||||
"metadata": {"source": "test2"},
|
||||
},
|
||||
]
|
||||
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
mock_collection.add.assert_called_once_with(
|
||||
ids=["custom_id_1", "custom_id_2"],
|
||||
documents=["First document", "Second document"],
|
||||
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_empty_list_raises_error(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test that aadd_documents raises error for empty documents list."""
|
||||
with pytest.raises(ValueError, match="Documents list cannot be empty"):
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=[]
|
||||
)
|
||||
|
||||
def test_search(self, client, mock_chromadb_client):
|
||||
"""Test that search queries the collection correctly."""
|
||||
mock_collection = Mock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_collection.query.return_value = {
|
||||
"ids": [["doc1", "doc2"]],
|
||||
"documents": [["Document 1", "Document 2"]],
|
||||
"metadatas": [[{"source": "test1"}, {"source": "test2"}]],
|
||||
"distances": [[0.1, 0.3]],
|
||||
}
|
||||
|
||||
results = client.search(collection_name="test_collection", query="test query")
|
||||
|
||||
mock_chromadb_client.get_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=client.embedding_function,
|
||||
)
|
||||
mock_collection.query.assert_called_once_with(
|
||||
query_texts=["test query"],
|
||||
n_results=10,
|
||||
where=None,
|
||||
where_document=None,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["id"] == "doc1"
|
||||
assert results[0]["content"] == "Document 1"
|
||||
assert results[0]["metadata"] == {"source": "test1"}
|
||||
assert results[0]["score"] == 0.95
|
||||
|
||||
def test_search_with_optional_params(self, client, mock_chromadb_client):
|
||||
"""Test search with optional parameters."""
|
||||
mock_collection = Mock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_collection.query.return_value = {
|
||||
"ids": [["doc1", "doc2", "doc3"]],
|
||||
"documents": [["Document 1", "Document 2", "Document 3"]],
|
||||
"metadatas": [
|
||||
[{"source": "test1"}, {"source": "test2"}, {"source": "test3"}]
|
||||
],
|
||||
"distances": [[0.1, 0.3, 1.5]], # Last one will be filtered by threshold
|
||||
}
|
||||
|
||||
results = client.search(
|
||||
collection_name="test_collection",
|
||||
query="test query",
|
||||
limit=5,
|
||||
metadata_filter={"source": "test"},
|
||||
score_threshold=0.7,
|
||||
)
|
||||
|
||||
mock_collection.query.assert_called_once_with(
|
||||
query_texts=["test query"],
|
||||
n_results=5,
|
||||
where={"source": "test"},
|
||||
where_document=None,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, async_client, mock_async_chromadb_client) -> None:
|
||||
"""Test that asearch queries the collection correctly."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
mock_collection.query = AsyncMock(
|
||||
return_value={
|
||||
"ids": [["doc1", "doc2"]],
|
||||
"documents": [["Document 1", "Document 2"]],
|
||||
"metadatas": [[{"source": "test1"}, {"source": "test2"}]],
|
||||
"distances": [[0.1, 0.3]],
|
||||
}
|
||||
)
|
||||
|
||||
results = await async_client.asearch(
|
||||
collection_name="test_collection", query="test query"
|
||||
)
|
||||
|
||||
mock_async_chromadb_client.get_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=async_client.embedding_function,
|
||||
)
|
||||
mock_collection.query.assert_called_once_with(
|
||||
query_texts=["test query"],
|
||||
n_results=10,
|
||||
where=None,
|
||||
where_document=None,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["id"] == "doc1"
|
||||
assert results[0]["content"] == "Document 1"
|
||||
assert results[0]["metadata"] == {"source": "test1"}
|
||||
assert results[0]["score"] == 0.95
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_with_optional_params(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test asearch with optional parameters."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
mock_collection.query = AsyncMock(
|
||||
return_value={
|
||||
"ids": [["doc1", "doc2", "doc3"]],
|
||||
"documents": [["Document 1", "Document 2", "Document 3"]],
|
||||
"metadatas": [
|
||||
[{"source": "test1"}, {"source": "test2"}, {"source": "test3"}]
|
||||
],
|
||||
"distances": [
|
||||
[0.1, 0.3, 1.5]
|
||||
], # Last one will be filtered by threshold
|
||||
}
|
||||
)
|
||||
|
||||
results = await async_client.asearch(
|
||||
collection_name="test_collection",
|
||||
query="test query",
|
||||
limit=5,
|
||||
metadata_filter={"source": "test"},
|
||||
score_threshold=0.7,
|
||||
)
|
||||
|
||||
mock_collection.query.assert_called_once_with(
|
||||
query_texts=["test query"],
|
||||
n_results=5,
|
||||
where={"source": "test"},
|
||||
where_document=None,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
)
|
||||
|
||||
# Only 2 results should pass the score threshold
|
||||
assert len(results) == 2
|
||||
|
||||
def test_delete_collection(self, client, mock_chromadb_client):
|
||||
"""Test that delete_collection calls the underlying client correctly."""
|
||||
client.delete_collection(collection_name="test_collection")
|
||||
|
||||
mock_chromadb_client.delete_collection.assert_called_once_with(
|
||||
name="test_collection"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adelete_collection(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test that adelete_collection calls the underlying client correctly."""
|
||||
mock_async_chromadb_client.delete_collection = AsyncMock(return_value=None)
|
||||
|
||||
await async_client.adelete_collection(collection_name="test_collection")
|
||||
|
||||
mock_async_chromadb_client.delete_collection.assert_called_once_with(
|
||||
name="test_collection"
|
||||
)
|
||||
|
||||
def test_reset(self, client, mock_chromadb_client):
|
||||
"""Test that reset calls the underlying client correctly."""
|
||||
mock_chromadb_client.reset.return_value = True
|
||||
|
||||
client.reset()
|
||||
|
||||
mock_chromadb_client.reset.assert_called_once_with()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_areset(self, async_client, mock_async_chromadb_client) -> None:
|
||||
"""Test that areset calls the underlying client correctly."""
|
||||
mock_async_chromadb_client.reset = AsyncMock(return_value=True)
|
||||
|
||||
await async_client.areset()
|
||||
|
||||
mock_async_chromadb_client.reset.assert_called_once_with()
|
||||
@@ -345,8 +345,6 @@ def test_output_pydantic_hierarchical():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_output_json_sequential():
|
||||
import uuid
|
||||
|
||||
class ScoreOutput(BaseModel):
|
||||
score: int
|
||||
|
||||
@@ -357,12 +355,11 @@ def test_output_json_sequential():
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
output_file = f"score_{uuid.uuid4()}.json"
|
||||
task = Task(
|
||||
description="Give me an integer score between 1-5 for the following title: 'The impact of AI in the future of work'",
|
||||
expected_output="The score of the title.",
|
||||
output_json=ScoreOutput,
|
||||
output_file=output_file,
|
||||
output_file="score.json",
|
||||
agent=scorer,
|
||||
)
|
||||
|
||||
@@ -371,9 +368,6 @@ def test_output_json_sequential():
|
||||
assert '{"score": 4}' == result.json
|
||||
assert result.to_dict() == {"score": 4}
|
||||
|
||||
if os.path.exists(output_file):
|
||||
os.remove(output_file)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_output_json_hierarchical():
|
||||
@@ -404,7 +398,6 @@ def test_output_json_hierarchical():
|
||||
assert result.json == '{"score": 4}'
|
||||
assert result.to_dict() == {"score": 4}
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_inject_date():
|
||||
reporter = Agent(
|
||||
@@ -429,7 +422,6 @@ def test_inject_date():
|
||||
result = crew.kickoff()
|
||||
assert "2025-05-21" in result.raw
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_inject_date_custom_format():
|
||||
reporter = Agent(
|
||||
@@ -455,7 +447,6 @@ def test_inject_date_custom_format():
|
||||
result = crew.kickoff()
|
||||
assert "May 21, 2025" in result.raw
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_no_inject_date():
|
||||
reporter = Agent(
|
||||
@@ -659,8 +650,6 @@ def test_save_task_output():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_save_task_json_output():
|
||||
from unittest.mock import patch
|
||||
|
||||
class ScoreOutput(BaseModel):
|
||||
score: int
|
||||
|
||||
@@ -680,25 +669,17 @@ def test_save_task_json_output():
|
||||
)
|
||||
|
||||
crew = Crew(agents=[scorer], tasks=[task])
|
||||
crew.kickoff()
|
||||
|
||||
# Mock only the _save_file method to avoid actual file I/O
|
||||
with patch.object(Task, "_save_file") as mock_save:
|
||||
result = crew.kickoff()
|
||||
assert result is not None
|
||||
mock_save.assert_called_once()
|
||||
|
||||
call_args = mock_save.call_args
|
||||
if call_args:
|
||||
saved_content = call_args[0][0]
|
||||
if isinstance(saved_content, str):
|
||||
data = json.loads(saved_content)
|
||||
assert "score" in data
|
||||
output_file_exists = os.path.exists("score.json")
|
||||
assert output_file_exists
|
||||
assert {"score": 4} == json.loads(open("score.json").read())
|
||||
if output_file_exists:
|
||||
os.remove("score.json")
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_save_task_pydantic_output():
|
||||
import uuid
|
||||
|
||||
class ScoreOutput(BaseModel):
|
||||
score: int
|
||||
|
||||
@@ -709,11 +690,10 @@ def test_save_task_pydantic_output():
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
output_file = f"score_{uuid.uuid4()}.json"
|
||||
task = Task(
|
||||
description="Give me an integer score between 1-5 for the following title: 'The impact of AI in the future of work'",
|
||||
expected_output="The score of the title.",
|
||||
output_file=output_file,
|
||||
output_file="score.json",
|
||||
output_pydantic=ScoreOutput,
|
||||
agent=scorer,
|
||||
)
|
||||
@@ -721,11 +701,11 @@ def test_save_task_pydantic_output():
|
||||
crew = Crew(agents=[scorer], tasks=[task])
|
||||
crew.kickoff()
|
||||
|
||||
output_file_exists = os.path.exists(output_file)
|
||||
output_file_exists = os.path.exists("score.json")
|
||||
assert output_file_exists
|
||||
assert {"score": 4} == json.loads(open(output_file).read())
|
||||
assert {"score": 4} == json.loads(open("score.json").read())
|
||||
if output_file_exists:
|
||||
os.remove(output_file)
|
||||
os.remove("score.json")
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -1156,67 +1136,62 @@ def test_output_file_validation():
|
||||
def test_create_directory_true():
|
||||
"""Test that directories are created when create_directory=True."""
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
output_path = "test_create_dir/output.txt"
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file=output_path,
|
||||
create_directory=True,
|
||||
)
|
||||
|
||||
|
||||
resolved_path = Path(output_path).expanduser().resolve()
|
||||
resolved_dir = resolved_path.parent
|
||||
|
||||
|
||||
if resolved_path.exists():
|
||||
resolved_path.unlink()
|
||||
if resolved_dir.exists():
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(resolved_dir)
|
||||
|
||||
|
||||
assert not resolved_dir.exists()
|
||||
|
||||
|
||||
task._save_file("test content")
|
||||
|
||||
|
||||
assert resolved_dir.exists()
|
||||
assert resolved_path.exists()
|
||||
|
||||
|
||||
if resolved_path.exists():
|
||||
resolved_path.unlink()
|
||||
if resolved_dir.exists():
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(resolved_dir)
|
||||
|
||||
|
||||
def test_create_directory_false():
|
||||
"""Test that directories are not created when create_directory=False."""
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
output_path = "nonexistent_test_dir/output.txt"
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file=output_path,
|
||||
create_directory=False,
|
||||
)
|
||||
|
||||
|
||||
resolved_path = Path(output_path).expanduser().resolve()
|
||||
resolved_dir = resolved_path.parent
|
||||
|
||||
|
||||
if resolved_dir.exists():
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(resolved_dir)
|
||||
|
||||
|
||||
assert not resolved_dir.exists()
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Directory .* does not exist and create_directory is False"
|
||||
):
|
||||
|
||||
with pytest.raises(RuntimeError, match="Directory .* does not exist and create_directory is False"):
|
||||
task._save_file("test content")
|
||||
|
||||
|
||||
@@ -1227,35 +1202,34 @@ def test_create_directory_default():
|
||||
expected_output="Test output",
|
||||
output_file="output.txt",
|
||||
)
|
||||
|
||||
|
||||
assert task.create_directory is True
|
||||
|
||||
|
||||
def test_create_directory_with_existing_directory():
|
||||
"""Test that create_directory=False works when directory already exists."""
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
output_path = "existing_test_dir/output.txt"
|
||||
|
||||
|
||||
resolved_path = Path(output_path).expanduser().resolve()
|
||||
resolved_dir = resolved_path.parent
|
||||
resolved_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file=output_path,
|
||||
create_directory=False,
|
||||
)
|
||||
|
||||
|
||||
task._save_file("test content")
|
||||
assert resolved_path.exists()
|
||||
|
||||
|
||||
if resolved_path.exists():
|
||||
resolved_path.unlink()
|
||||
if resolved_dir.exists():
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(resolved_dir)
|
||||
|
||||
|
||||
@@ -1267,7 +1241,7 @@ def test_github_issue_3149_reproduction():
|
||||
output_file="test_output.txt",
|
||||
create_directory=True,
|
||||
)
|
||||
|
||||
|
||||
assert task.create_directory is True
|
||||
assert task.output_file == "test_output.txt"
|
||||
|
||||
@@ -79,10 +79,8 @@ def test_telemetry_fails_due_connect_timeout(export_mock, logger_mock):
|
||||
|
||||
trace.get_tracer_provider().force_flush()
|
||||
|
||||
assert export_mock.called
|
||||
assert logger_mock.call_count == export_mock.call_count
|
||||
for call in logger_mock.call_args_list:
|
||||
assert call[0][0] == error
|
||||
export_mock.assert_called_once()
|
||||
logger_mock.assert_called_once_with(error)
|
||||
|
||||
|
||||
@pytest.mark.telemetry
|
||||
|
||||
@@ -7,37 +7,37 @@ from crewai.task import Task
|
||||
|
||||
def test_agent_inject_date():
|
||||
"""Test that the inject_date flag injects the current date into the task.
|
||||
|
||||
|
||||
Tests that when inject_date=True, the current date is added to the task description.
|
||||
"""
|
||||
with patch("datetime.datetime") as mock_datetime:
|
||||
with patch('datetime.datetime') as mock_datetime:
|
||||
mock_datetime.now.return_value = datetime(2025, 1, 1)
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="test_agent",
|
||||
goal="test_goal",
|
||||
backstory="test_backstory",
|
||||
inject_date=True,
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
# Store original description
|
||||
original_description = task.description
|
||||
|
||||
|
||||
agent._inject_date_to_task(task)
|
||||
|
||||
|
||||
assert "Current Date: 2025-01-01" in task.description
|
||||
assert task.description != original_description
|
||||
|
||||
|
||||
def test_agent_without_inject_date():
|
||||
"""Test that without inject_date flag, no date is injected.
|
||||
|
||||
|
||||
Tests that when inject_date=False (default), no date is added to the task description.
|
||||
"""
|
||||
agent = Agent(
|
||||
@@ -46,28 +46,28 @@ def test_agent_without_inject_date():
|
||||
backstory="test_backstory",
|
||||
# inject_date is False by default
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
original_description = task.description
|
||||
|
||||
|
||||
agent._inject_date_to_task(task)
|
||||
|
||||
|
||||
assert task.description == original_description
|
||||
|
||||
|
||||
def test_agent_inject_date_custom_format():
|
||||
"""Test that the inject_date flag with custom date_format works correctly.
|
||||
|
||||
|
||||
Tests that when inject_date=True with a custom date_format, the date is formatted correctly.
|
||||
"""
|
||||
with patch("datetime.datetime") as mock_datetime:
|
||||
with patch('datetime.datetime') as mock_datetime:
|
||||
mock_datetime.now.return_value = datetime(2025, 1, 1)
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="test_agent",
|
||||
goal="test_goal",
|
||||
@@ -75,25 +75,25 @@ def test_agent_inject_date_custom_format():
|
||||
inject_date=True,
|
||||
date_format="%d/%m/%Y",
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
# Store original description
|
||||
original_description = task.description
|
||||
|
||||
|
||||
agent._inject_date_to_task(task)
|
||||
|
||||
|
||||
assert "Current Date: 01/01/2025" in task.description
|
||||
assert task.description != original_description
|
||||
|
||||
|
||||
def test_agent_inject_date_invalid_format():
|
||||
"""Test error handling with invalid date format.
|
||||
|
||||
|
||||
Tests that when an invalid date_format is provided, the task description remains unchanged.
|
||||
"""
|
||||
agent = Agent(
|
||||
@@ -103,15 +103,15 @@ def test_agent_inject_date_invalid_format():
|
||||
inject_date=True,
|
||||
date_format="invalid",
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
original_description = task.description
|
||||
|
||||
|
||||
agent._inject_date_to_task(task)
|
||||
|
||||
|
||||
assert task.description == original_description
|
||||
@@ -1,177 +0,0 @@
|
||||
"""Regression tests for flow listener resumability fix.
|
||||
|
||||
These tests ensure that:
|
||||
1. HITL flows can resume properly without re-executing completed methods
|
||||
2. Cyclic flows can re-execute methods on each iteration
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
from crewai.flow.flow import Flow, listen, router, start
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
|
||||
|
||||
def test_hitl_resumption_skips_completed_listeners(tmp_path):
|
||||
"""Test that HITL resumption skips completed listener methods but continues chains."""
|
||||
db_path = tmp_path / "test_flows.db"
|
||||
persistence = SQLiteFlowPersistence(str(db_path))
|
||||
execution_log = []
|
||||
|
||||
class HitlFlow(Flow[Dict[str, str]]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
execution_log.append("step_1_executed")
|
||||
self.state["step1"] = "done"
|
||||
return "step1_result"
|
||||
|
||||
@listen(step_1)
|
||||
def step_2(self):
|
||||
execution_log.append("step_2_executed")
|
||||
self.state["step2"] = "done"
|
||||
return "step2_result"
|
||||
|
||||
@listen(step_2)
|
||||
def step_3(self):
|
||||
execution_log.append("step_3_executed")
|
||||
self.state["step3"] = "done"
|
||||
return "step3_result"
|
||||
|
||||
flow1 = HitlFlow(persistence=persistence)
|
||||
flow1.kickoff()
|
||||
flow_id = flow1.state["id"]
|
||||
|
||||
assert execution_log == ["step_1_executed", "step_2_executed", "step_3_executed"]
|
||||
|
||||
flow2 = HitlFlow(persistence=persistence)
|
||||
flow2._completed_methods = {"step_1", "step_2"} # Simulate partial completion
|
||||
execution_log.clear()
|
||||
|
||||
flow2.kickoff(inputs={"id": flow_id})
|
||||
|
||||
assert "step_1_executed" not in execution_log
|
||||
assert "step_2_executed" not in execution_log
|
||||
assert "step_3_executed" in execution_log
|
||||
|
||||
|
||||
def test_cyclic_flow_re_executes_on_each_iteration():
|
||||
"""Test that cyclic flows properly re-execute methods on each iteration."""
|
||||
execution_log = []
|
||||
|
||||
class CyclicFlowTest(Flow[Dict[str, str]]):
|
||||
iteration = 0
|
||||
max_iterations = 3
|
||||
|
||||
@start("loop")
|
||||
def step_1(self):
|
||||
if self.iteration >= self.max_iterations:
|
||||
return None
|
||||
execution_log.append(f"step_1_{self.iteration}")
|
||||
return f"result_{self.iteration}"
|
||||
|
||||
@listen(step_1)
|
||||
def step_2(self):
|
||||
execution_log.append(f"step_2_{self.iteration}")
|
||||
|
||||
@router(step_2)
|
||||
def step_3(self):
|
||||
execution_log.append(f"step_3_{self.iteration}")
|
||||
self.iteration += 1
|
||||
if self.iteration < self.max_iterations:
|
||||
return "loop"
|
||||
return "exit"
|
||||
|
||||
flow = CyclicFlowTest()
|
||||
flow.kickoff()
|
||||
|
||||
expected = []
|
||||
for i in range(3):
|
||||
expected.extend([f"step_1_{i}", f"step_2_{i}", f"step_3_{i}"])
|
||||
|
||||
assert execution_log == expected
|
||||
|
||||
|
||||
def test_conditional_start_with_resumption(tmp_path):
|
||||
"""Test that conditional start methods work correctly with resumption."""
|
||||
db_path = tmp_path / "test_flows.db"
|
||||
persistence = SQLiteFlowPersistence(str(db_path))
|
||||
execution_log = []
|
||||
|
||||
class ConditionalStartFlow(Flow[Dict[str, str]]):
|
||||
@start()
|
||||
def init(self):
|
||||
execution_log.append("init")
|
||||
return "initialized"
|
||||
|
||||
@router(init)
|
||||
def route_to_branch(self):
|
||||
execution_log.append("router")
|
||||
return "branch_a"
|
||||
|
||||
@start("branch_a")
|
||||
def branch_a_start(self):
|
||||
execution_log.append("branch_a_start")
|
||||
self.state["branch"] = "a"
|
||||
|
||||
@listen(branch_a_start)
|
||||
def branch_a_process(self):
|
||||
execution_log.append("branch_a_process")
|
||||
self.state["processed"] = "yes"
|
||||
|
||||
flow1 = ConditionalStartFlow(persistence=persistence)
|
||||
flow1.kickoff()
|
||||
flow_id = flow1.state["id"]
|
||||
|
||||
assert execution_log == ["init", "router", "branch_a_start", "branch_a_process"]
|
||||
|
||||
flow2 = ConditionalStartFlow(persistence=persistence)
|
||||
flow2._completed_methods = {"init", "route_to_branch", "branch_a_start"}
|
||||
execution_log.clear()
|
||||
|
||||
flow2.kickoff(inputs={"id": flow_id})
|
||||
|
||||
assert execution_log == ["branch_a_process"]
|
||||
|
||||
|
||||
def test_cyclic_flow_with_conditional_start():
|
||||
"""Test that cyclic flows work properly with conditional start methods."""
|
||||
execution_log = []
|
||||
|
||||
class CyclicConditionalFlow(Flow[Dict[str, str]]):
|
||||
iteration = 0
|
||||
|
||||
@start()
|
||||
def initial(self):
|
||||
execution_log.append("initial")
|
||||
return "init_done"
|
||||
|
||||
@router(initial)
|
||||
def route_to_cycle(self):
|
||||
execution_log.append("router_initial")
|
||||
return "loop"
|
||||
|
||||
@start("loop")
|
||||
def cycle_entry(self):
|
||||
execution_log.append(f"cycle_{self.iteration}")
|
||||
self.iteration += 1
|
||||
|
||||
@router(cycle_entry)
|
||||
def cycle_router(self):
|
||||
execution_log.append(f"router_{self.iteration - 1}")
|
||||
if self.iteration < 3:
|
||||
return "loop"
|
||||
return "exit"
|
||||
|
||||
flow = CyclicConditionalFlow()
|
||||
flow.kickoff()
|
||||
|
||||
expected = [
|
||||
"initial",
|
||||
"router_initial",
|
||||
"cycle_0",
|
||||
"router_0",
|
||||
"cycle_1",
|
||||
"router_1",
|
||||
"cycle_2",
|
||||
"router_2",
|
||||
]
|
||||
|
||||
assert execution_log == expected
|
||||
@@ -318,17 +318,11 @@ def test_sets_parent_flow_when_inside_flow():
|
||||
flow.kickoff()
|
||||
assert captured_agent.parent_flow is flow
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_guardrail_is_called_using_string():
|
||||
guardrail_events = defaultdict(list)
|
||||
from crewai.utilities.events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
|
||||
from crewai.utilities.events import LLMGuardrailCompletedEvent, LLMGuardrailStartedEvent
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
guardrail_events["started"].append(event)
|
||||
@@ -346,26 +340,17 @@ def test_guardrail_is_called_using_string():
|
||||
|
||||
result = agent.kickoff(messages="Top 10 best players in the world?")
|
||||
|
||||
assert len(guardrail_events["started"]) == 2
|
||||
assert len(guardrail_events["completed"]) == 2
|
||||
assert not guardrail_events["completed"][0].success
|
||||
assert guardrail_events["completed"][1].success
|
||||
assert (
|
||||
"Here are the top 10 best soccer players in the world, focusing exclusively on Brazilian players"
|
||||
in result.raw
|
||||
)
|
||||
|
||||
assert len(guardrail_events['started']) == 2
|
||||
assert len(guardrail_events['completed']) == 2
|
||||
assert not guardrail_events['completed'][0].success
|
||||
assert guardrail_events['completed'][1].success
|
||||
assert "Here are the top 10 best soccer players in the world, focusing exclusively on Brazilian players" in result.raw
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_guardrail_is_called_using_callable():
|
||||
guardrail_events = defaultdict(list)
|
||||
from crewai.utilities.events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
|
||||
from crewai.utilities.events import LLMGuardrailCompletedEvent, LLMGuardrailStartedEvent
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
guardrail_events["started"].append(event)
|
||||
@@ -383,22 +368,16 @@ def test_guardrail_is_called_using_callable():
|
||||
|
||||
result = agent.kickoff(messages="Top 1 best players in the world?")
|
||||
|
||||
assert len(guardrail_events["started"]) == 1
|
||||
assert len(guardrail_events["completed"]) == 1
|
||||
assert guardrail_events["completed"][0].success
|
||||
assert len(guardrail_events['started']) == 1
|
||||
assert len(guardrail_events['completed']) == 1
|
||||
assert guardrail_events['completed'][0].success
|
||||
assert "Pelé - Santos, 1958" in result.raw
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_guardrail_reached_attempt_limit():
|
||||
guardrail_events = defaultdict(list)
|
||||
from crewai.utilities.events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
|
||||
from crewai.utilities.events import LLMGuardrailCompletedEvent, LLMGuardrailStartedEvent
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
guardrail_events["started"].append(event)
|
||||
@@ -411,23 +390,18 @@ def test_guardrail_reached_attempt_limit():
|
||||
role="Sports Analyst",
|
||||
goal="Gather information about the best soccer players",
|
||||
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
|
||||
guardrail=lambda output: (
|
||||
False,
|
||||
"You are not allowed to include Brazilian players",
|
||||
),
|
||||
guardrail=lambda output: (False, "You are not allowed to include Brazilian players"),
|
||||
guardrail_max_retries=2,
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
Exception, match="Agent's guardrail failed validation after 2 retries"
|
||||
):
|
||||
with pytest.raises(Exception, match="Agent's guardrail failed validation after 2 retries"):
|
||||
agent.kickoff(messages="Top 10 best players in the world?")
|
||||
|
||||
assert len(guardrail_events["started"]) == 3 # 2 retries + 1 initial call
|
||||
assert len(guardrail_events["completed"]) == 3 # 2 retries + 1 initial call
|
||||
assert not guardrail_events["completed"][0].success
|
||||
assert not guardrail_events["completed"][1].success
|
||||
assert not guardrail_events["completed"][2].success
|
||||
assert len(guardrail_events['started']) == 3 # 2 retries + 1 initial call
|
||||
assert len(guardrail_events['completed']) == 3 # 2 retries + 1 initial call
|
||||
assert not guardrail_events['completed'][0].success
|
||||
assert not guardrail_events['completed'][1].success
|
||||
assert not guardrail_events['completed'][2].success
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -440,35 +414,22 @@ def test_agent_output_when_guardrail_returns_base_model():
|
||||
role="Sports Analyst",
|
||||
goal="Gather information about the best soccer players",
|
||||
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
|
||||
guardrail=lambda output: (
|
||||
True,
|
||||
Player(name="Lionel Messi", country="Argentina"),
|
||||
),
|
||||
guardrail=lambda output: (True, Player(name="Lionel Messi", country="Argentina")),
|
||||
)
|
||||
|
||||
result = agent.kickoff(messages="Top 10 best players in the world?")
|
||||
|
||||
assert result.pydantic == Player(name="Lionel Messi", country="Argentina")
|
||||
|
||||
|
||||
def test_lite_agent_with_custom_llm_and_guardrails():
|
||||
"""Test that CustomLLM (inheriting from BaseLLM) works with guardrails."""
|
||||
|
||||
class CustomLLM(BaseLLM):
|
||||
def __init__(self, response: str = "Custom response"):
|
||||
super().__init__(model="custom-model")
|
||||
self.response = response
|
||||
self.call_count = 0
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages,
|
||||
tools=None,
|
||||
callbacks=None,
|
||||
available_functions=None,
|
||||
from_task=None,
|
||||
from_agent=None,
|
||||
) -> str:
|
||||
def call(self, messages, tools=None, callbacks=None, available_functions=None, from_task=None, from_agent=None) -> str:
|
||||
self.call_count += 1
|
||||
|
||||
if "valid" in str(messages) and "feedback" in str(messages):
|
||||
@@ -495,7 +456,7 @@ def test_lite_agent_with_custom_llm_and_guardrails():
|
||||
goal="Analyze soccer players",
|
||||
backstory="You analyze soccer players and their performance.",
|
||||
llm=custom_llm,
|
||||
guardrail="Only include Brazilian players",
|
||||
guardrail="Only include Brazilian players"
|
||||
)
|
||||
|
||||
result = agent.kickoff("Tell me about the best soccer players")
|
||||
@@ -513,7 +474,7 @@ def test_lite_agent_with_custom_llm_and_guardrails():
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm=custom_llm2,
|
||||
guardrail=test_guardrail,
|
||||
guardrail=test_guardrail
|
||||
)
|
||||
|
||||
result2 = agent2.kickoff("Test message")
|
||||
@@ -523,12 +484,12 @@ def test_lite_agent_with_custom_llm_and_guardrails():
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_lite_agent_with_invalid_llm():
|
||||
"""Test that LiteAgent raises proper error when create_llm returns None."""
|
||||
with patch("crewai.lite_agent.create_llm", return_value=None):
|
||||
with patch('crewai.lite_agent.create_llm', return_value=None):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
LiteAgent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm="invalid-model",
|
||||
llm="invalid-model"
|
||||
)
|
||||
assert "Expected LLM instance of type BaseLLM" in str(exc_info.value)
|
||||
assert "Expected LLM instance of type BaseLLM" in str(exc_info.value)
|
||||
@@ -76,29 +76,11 @@ def base_task(base_agent):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_event_listener_singleton():
|
||||
"""Reset EventListener singleton for clean test state."""
|
||||
original_instance = EventListener._instance
|
||||
original_initialized = (
|
||||
getattr(EventListener._instance, "_initialized", False)
|
||||
if EventListener._instance
|
||||
else False
|
||||
)
|
||||
|
||||
EventListener._instance = None
|
||||
|
||||
yield
|
||||
|
||||
EventListener._instance = original_instance
|
||||
if original_instance and original_initialized:
|
||||
EventListener._instance._initialized = original_initialized
|
||||
event_listener = EventListener()
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_emits_start_kickoff_event(
|
||||
base_agent, base_task, reset_event_listener_singleton
|
||||
):
|
||||
def test_crew_emits_start_kickoff_event(base_agent, base_task):
|
||||
received_events = []
|
||||
mock_span = Mock()
|
||||
|
||||
@@ -106,23 +88,18 @@ def test_crew_emits_start_kickoff_event(
|
||||
def handle_crew_start(source, event):
|
||||
received_events.append(event)
|
||||
|
||||
mock_telemetry = Mock()
|
||||
mock_telemetry.crew_execution_span = Mock(return_value=mock_span)
|
||||
mock_telemetry.end_crew = Mock(return_value=mock_span)
|
||||
mock_telemetry.set_tracer = Mock()
|
||||
mock_telemetry.task_started = Mock(return_value=mock_span)
|
||||
mock_telemetry.task_ended = Mock(return_value=mock_span)
|
||||
|
||||
# Patch the Telemetry class to return our mock
|
||||
with patch(
|
||||
"crewai.utilities.events.event_listener.Telemetry", return_value=mock_telemetry
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
with (
|
||||
patch.object(
|
||||
event_listener._telemetry, "crew_execution_span", return_value=mock_span
|
||||
) as mock_crew_execution_span,
|
||||
patch.object(
|
||||
event_listener._telemetry, "end_crew", return_value=mock_span
|
||||
) as mock_crew_ended,
|
||||
):
|
||||
# Now when Crew creates EventListener, it will use our mocked telemetry
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
|
||||
mock_telemetry.crew_execution_span.assert_called_once_with(crew, None)
|
||||
mock_telemetry.end_crew.assert_called_once_with(crew, "hi")
|
||||
mock_crew_execution_span.assert_called_once_with(crew, None)
|
||||
mock_crew_ended.assert_called_once_with(crew, "hi")
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].crew_name == "TestCrew"
|
||||
@@ -151,6 +128,7 @@ def test_crew_emits_end_kickoff_event(base_agent, base_task):
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_emits_test_kickoff_type_event(base_agent, base_task):
|
||||
received_events = []
|
||||
mock_span = Mock()
|
||||
|
||||
@crewai_event_bus.on(CrewTestStartedEvent)
|
||||
def handle_crew_end(source, event):
|
||||
@@ -165,8 +143,21 @@ def test_crew_emits_test_kickoff_type_event(base_agent, base_task):
|
||||
received_events.append(event)
|
||||
|
||||
eval_llm = LLM(model="gpt-4o-mini")
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.test(n_iterations=1, eval_llm=eval_llm)
|
||||
with (
|
||||
patch.object(
|
||||
event_listener._telemetry, "test_execution_span", return_value=mock_span
|
||||
) as mock_crew_execution_span,
|
||||
):
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.test(n_iterations=1, eval_llm=eval_llm)
|
||||
|
||||
# Verify the call was made with correct argument types and values
|
||||
assert mock_crew_execution_span.call_count == 1
|
||||
args = mock_crew_execution_span.call_args[0]
|
||||
assert isinstance(args[0], Crew)
|
||||
assert args[1] == 1
|
||||
assert args[2] is None
|
||||
assert args[3] == eval_llm
|
||||
|
||||
assert len(received_events) == 3
|
||||
assert received_events[0].crew_name == "TestCrew"
|
||||
@@ -223,9 +214,7 @@ def test_crew_emits_start_task_event(base_agent, base_task):
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_emits_end_task_event(
|
||||
base_agent, base_task, reset_event_listener_singleton
|
||||
):
|
||||
def test_crew_emits_end_task_event(base_agent, base_task):
|
||||
received_events = []
|
||||
|
||||
@crewai_event_bus.on(TaskCompletedEvent)
|
||||
@@ -233,22 +222,19 @@ def test_crew_emits_end_task_event(
|
||||
received_events.append(event)
|
||||
|
||||
mock_span = Mock()
|
||||
|
||||
mock_telemetry = Mock()
|
||||
mock_telemetry.task_started = Mock(return_value=mock_span)
|
||||
mock_telemetry.task_ended = Mock(return_value=mock_span)
|
||||
mock_telemetry.set_tracer = Mock()
|
||||
mock_telemetry.crew_execution_span = Mock()
|
||||
mock_telemetry.end_crew = Mock()
|
||||
|
||||
with patch(
|
||||
"crewai.utilities.events.event_listener.Telemetry", return_value=mock_telemetry
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
with (
|
||||
patch.object(
|
||||
event_listener._telemetry, "task_started", return_value=mock_span
|
||||
) as mock_task_started,
|
||||
patch.object(
|
||||
event_listener._telemetry, "task_ended", return_value=mock_span
|
||||
) as mock_task_ended,
|
||||
):
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
|
||||
mock_telemetry.task_started.assert_called_once_with(crew=crew, task=base_task)
|
||||
mock_telemetry.task_ended.assert_called_once_with(mock_span, base_task, crew)
|
||||
mock_task_started.assert_called_once_with(crew=crew, task=base_task)
|
||||
mock_task_ended.assert_called_once_with(mock_span, base_task, crew)
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0].timestamp, datetime)
|
||||
@@ -437,7 +423,7 @@ def test_tools_emits_error_events():
|
||||
assert isinstance(received_events[0].timestamp, datetime)
|
||||
|
||||
|
||||
def test_flow_emits_start_event(reset_event_listener_singleton):
|
||||
def test_flow_emits_start_event():
|
||||
received_events = []
|
||||
mock_span = Mock()
|
||||
|
||||
@@ -450,21 +436,15 @@ def test_flow_emits_start_event(reset_event_listener_singleton):
|
||||
def begin(self):
|
||||
return "started"
|
||||
|
||||
mock_telemetry = Mock()
|
||||
mock_telemetry.flow_execution_span = Mock(return_value=mock_span)
|
||||
mock_telemetry.flow_creation_span = Mock()
|
||||
mock_telemetry.set_tracer = Mock()
|
||||
|
||||
with patch(
|
||||
"crewai.utilities.events.event_listener.Telemetry", return_value=mock_telemetry
|
||||
with (
|
||||
patch.object(
|
||||
event_listener._telemetry, "flow_execution_span", return_value=mock_span
|
||||
) as mock_flow_execution_span,
|
||||
):
|
||||
# Force creation of EventListener singleton with mocked telemetry
|
||||
_ = EventListener()
|
||||
|
||||
flow = TestFlow()
|
||||
flow.kickoff()
|
||||
|
||||
mock_telemetry.flow_execution_span.assert_called_once_with("TestFlow", ["begin"])
|
||||
mock_flow_execution_span.assert_called_once_with("TestFlow", ["begin"])
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].flow_name == "TestFlow"
|
||||
assert received_events[0].type == "flow_started"
|
||||
@@ -592,6 +572,7 @@ def test_multiple_handlers_for_same_event(base_agent, base_task):
|
||||
|
||||
def test_flow_emits_created_event():
|
||||
received_events = []
|
||||
mock_span = Mock()
|
||||
|
||||
@crewai_event_bus.on(FlowCreatedEvent)
|
||||
def handle_flow_created(source, event):
|
||||
@@ -602,8 +583,15 @@ def test_flow_emits_created_event():
|
||||
def begin(self):
|
||||
return "started"
|
||||
|
||||
flow = TestFlow()
|
||||
flow.kickoff()
|
||||
with (
|
||||
patch.object(
|
||||
event_listener._telemetry, "flow_creation_span", return_value=mock_span
|
||||
) as mock_flow_creation_span,
|
||||
):
|
||||
flow = TestFlow()
|
||||
flow.kickoff()
|
||||
|
||||
mock_flow_creation_span.assert_called_once_with("TestFlow")
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].flow_name == "TestFlow"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -9,9 +8,7 @@ from crewai.utilities.file_handler import PickleHandler
|
||||
|
||||
class TestPickleHandler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Use a unique file name for each test to avoid race conditions in parallel test execution
|
||||
unique_id = str(uuid.uuid4())
|
||||
self.file_name = f"test_data_{unique_id}.pkl"
|
||||
self.file_name = "test_data.pkl"
|
||||
self.file_path = os.path.join(os.getcwd(), self.file_name)
|
||||
self.handler = PickleHandler(self.file_name)
|
||||
|
||||
@@ -40,8 +37,6 @@ class TestPickleHandler(unittest.TestCase):
|
||||
def test_load_corrupted_file(self):
|
||||
with open(self.file_path, "wb") as file:
|
||||
file.write(b"corrupted data")
|
||||
file.flush()
|
||||
os.fsync(file.fileno()) # Ensure data is written to disk
|
||||
|
||||
with pytest.raises(Exception) as exc:
|
||||
self.handler.load()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
@@ -7,13 +6,10 @@ from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
class InternalCrewTrainingHandler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.temp_file = tempfile.NamedTemporaryFile(suffix=".pkl", delete=False)
|
||||
self.temp_file.close()
|
||||
self.handler = CrewTrainingHandler(self.temp_file.name)
|
||||
self.handler = CrewTrainingHandler("trained_data.pkl")
|
||||
|
||||
def tearDown(self):
|
||||
if os.path.exists(self.temp_file.name):
|
||||
os.remove(self.temp_file.name)
|
||||
os.remove("trained_data.pkl")
|
||||
del self.handler
|
||||
|
||||
def test_save_trained_data(self):
|
||||
@@ -26,22 +22,13 @@ class InternalCrewTrainingHandler(unittest.TestCase):
|
||||
assert data[agent_id] == trained_data
|
||||
|
||||
def test_append_existing_agent(self):
|
||||
agent_id = "agent1"
|
||||
initial_iteration = 0
|
||||
initial_data = {"param1": 1, "param2": 2}
|
||||
|
||||
self.handler.append(initial_iteration, agent_id, initial_data)
|
||||
|
||||
train_iteration = 1
|
||||
agent_id = "agent1"
|
||||
new_data = {"param3": 3, "param4": 4}
|
||||
self.handler.append(train_iteration, agent_id, new_data)
|
||||
|
||||
# Assert that the new data is appended correctly to the existing agent
|
||||
data = self.handler.load()
|
||||
assert agent_id in data
|
||||
assert initial_iteration in data[agent_id]
|
||||
assert train_iteration in data[agent_id]
|
||||
assert data[agent_id][initial_iteration] == initial_data
|
||||
assert data[agent_id][train_iteration] == new_data
|
||||
|
||||
def test_append_new_agent(self):
|
||||
|
||||
Reference in New Issue
Block a user