mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-04 13:48:31 +00:00
Compare commits
8 Commits
gl/feat/wo
...
lorenze/tr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3a8889ce61 | ||
|
|
d865a49f5a | ||
|
|
677fe9032c | ||
|
|
6e8c1f332f | ||
|
|
abe170cdc2 | ||
|
|
51767f2e15 | ||
|
|
dc41a0d13b | ||
|
|
6d02b64674 |
2
.github/workflows/build-uv-cache.yml
vendored
2
.github/workflows/build-uv-cache.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
- name: Install dependencies and populate cache
|
||||
run: |
|
||||
echo "Building global UV cache for Python ${{ matrix.python-version }}..."
|
||||
uv sync --all-groups --all-extras
|
||||
uv sync --all-groups --all-extras --no-install-project
|
||||
echo "Cache populated successfully"
|
||||
|
||||
- name: Save uv caches
|
||||
|
||||
102
.github/workflows/codeql.yml
vendored
102
.github/workflows/codeql.yml
vendored
@@ -1,102 +0,0 @@
|
||||
# For most projects, this workflow file will not need changing; you simply need
|
||||
# to commit it to your repository.
|
||||
#
|
||||
# You may wish to alter this file to override the set of languages analyzed,
|
||||
# or to provide custom queries or build logic.
|
||||
#
|
||||
# ******** NOTE ********
|
||||
# We have attempted to detect the languages in your repository. Please check
|
||||
# the `language` matrix defined below to confirm you have the correct set of
|
||||
# supported CodeQL languages.
|
||||
#
|
||||
name: "CodeQL Advanced"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
paths-ignore:
|
||||
- "src/crewai/cli/templates/**"
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths-ignore:
|
||||
- "src/crewai/cli/templates/**"
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: Analyze (${{ matrix.language }})
|
||||
# Runner size impacts CodeQL analysis time. To learn more, please see:
|
||||
# - https://gh.io/recommended-hardware-resources-for-running-codeql
|
||||
# - https://gh.io/supported-runners-and-hardware-resources
|
||||
# - https://gh.io/using-larger-runners (GitHub.com only)
|
||||
# Consider using larger runners or machines with greater resources for possible analysis time improvements.
|
||||
runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
|
||||
permissions:
|
||||
# required for all workflows
|
||||
security-events: write
|
||||
|
||||
# required to fetch internal or private CodeQL packs
|
||||
packages: read
|
||||
|
||||
# only required for workflows in private repositories
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- language: actions
|
||||
build-mode: none
|
||||
- language: python
|
||||
build-mode: none
|
||||
# CodeQL supports the following values keywords for 'language': 'actions', 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'rust', 'swift'
|
||||
# Use `c-cpp` to analyze code written in C, C++ or both
|
||||
# Use 'java-kotlin' to analyze code written in Java, Kotlin or both
|
||||
# Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
|
||||
# To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
|
||||
# see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
|
||||
# If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
|
||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# Add any setup steps before running the `github/codeql-action/init` action.
|
||||
# This includes steps like installing compilers or runtimes (`actions/setup-node`
|
||||
# or others). This is typically only required for manual builds.
|
||||
# - name: Setup runtime (example)
|
||||
# uses: actions/setup-example@v1
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
build-mode: ${{ matrix.build-mode }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
# By default, queries listed here will override any specified in a config file.
|
||||
# Prefix the list here with "+" to use these queries and those in the config file.
|
||||
|
||||
# For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
|
||||
# queries: security-extended,security-and-quality
|
||||
|
||||
# If the analyze step fails for one of the languages you are analyzing with
|
||||
# "We were unable to automatically build your code", modify the matrix above
|
||||
# to set the build mode to "manual" for that language. Then modify this step
|
||||
# to build your code.
|
||||
# ℹ️ Command-line programs to run using the OS shell.
|
||||
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
|
||||
- if: matrix.build-mode == 'manual'
|
||||
shell: bash
|
||||
run: |
|
||||
echo 'If you are using a "manual" build mode for one or more of the' \
|
||||
'languages you are analyzing, replace this with the commands to build' \
|
||||
'your code, for example:'
|
||||
echo ' make bootstrap'
|
||||
echo ' make release'
|
||||
exit 1
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
2
.github/workflows/linter.yml
vendored
2
.github/workflows/linter.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
enable-cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --all-packages --all-extras --no-install-project
|
||||
run: uv sync --all-groups --all-extras --no-install-project
|
||||
|
||||
- name: Get Changed Python Files
|
||||
id: changed-files
|
||||
|
||||
63
.github/workflows/tests.yml
vendored
63
.github/workflows/tests.yml
vendored
@@ -25,17 +25,17 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history for proper diff
|
||||
|
||||
# - name: Restore global uv cache
|
||||
# id: cache-restore
|
||||
# uses: actions/cache/restore@v4
|
||||
# with:
|
||||
# path: |
|
||||
# ~/.cache/uv
|
||||
# ~/.local/share/uv
|
||||
# .venv
|
||||
# key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
|
||||
# restore-keys: |
|
||||
# uv-main-py${{ matrix.python-version }}-
|
||||
- name: Restore global uv cache
|
||||
id: cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cache/uv
|
||||
~/.local/share/uv
|
||||
.venv
|
||||
key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
|
||||
restore-keys: |
|
||||
uv-main-py${{ matrix.python-version }}-
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
@@ -45,24 +45,24 @@ jobs:
|
||||
enable-cache: false
|
||||
|
||||
- name: Install the project
|
||||
run: uv sync --all-packages --all-extras
|
||||
run: uv sync --all-groups --all-extras
|
||||
|
||||
# - name: Restore test durations
|
||||
# uses: actions/cache/restore@v4
|
||||
# with:
|
||||
# path: .test_durations_py*
|
||||
# key: test-durations-py${{ matrix.python-version }}
|
||||
- name: Restore test durations
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: .test_durations_py*
|
||||
key: test-durations-py${{ matrix.python-version }}
|
||||
|
||||
- name: Run tests (group ${{ matrix.group }} of 8)
|
||||
run: |
|
||||
PYTHON_VERSION_SAFE=$(echo "${{ matrix.python-version }}" | tr '.' '_')
|
||||
DURATION_FILE=".test_durations_py${PYTHON_VERSION_SAFE}"
|
||||
|
||||
|
||||
# Temporarily always skip cached durations to fix test splitting
|
||||
# When durations don't match, pytest-split runs duplicate tests instead of splitting
|
||||
echo "Using even test splitting (duration cache disabled until fix merged)"
|
||||
DURATIONS_ARG=""
|
||||
|
||||
|
||||
# Original logic (disabled temporarily):
|
||||
# if [ ! -f "$DURATION_FILE" ]; then
|
||||
# echo "No cached durations found, tests will be split evenly"
|
||||
@@ -74,8 +74,8 @@ jobs:
|
||||
# echo "No test changes detected, using cached test durations for optimal splitting"
|
||||
# DURATIONS_ARG="--durations-path=${DURATION_FILE}"
|
||||
# fi
|
||||
|
||||
uv run pytest lib/crewai \
|
||||
|
||||
uv run pytest \
|
||||
--block-network \
|
||||
--timeout=30 \
|
||||
-vv \
|
||||
@@ -84,15 +84,14 @@ jobs:
|
||||
$DURATIONS_ARG \
|
||||
--durations=10 \
|
||||
-n auto \
|
||||
--maxfail=3 \
|
||||
-m "not requires_local_services"
|
||||
--maxfail=3
|
||||
|
||||
# - name: Save uv caches
|
||||
# if: steps.cache-restore.outputs.cache-hit != 'true'
|
||||
# uses: actions/cache/save@v4
|
||||
# with:
|
||||
# path: |
|
||||
# ~/.cache/uv
|
||||
# ~/.local/share/uv
|
||||
# .venv
|
||||
# key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
|
||||
- name: Save uv caches
|
||||
if: steps.cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cache/uv
|
||||
~/.local/share/uv
|
||||
.venv
|
||||
key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
|
||||
|
||||
2
.github/workflows/type-checker.yml
vendored
2
.github/workflows/type-checker.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
||||
enable-cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --all-packages --all-extras
|
||||
run: uv sync --all-groups --all-extras
|
||||
|
||||
- name: Get changed Python files
|
||||
id: changed-files
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
.pytest_cache
|
||||
__pycache__
|
||||
dist/
|
||||
lib/
|
||||
.env
|
||||
assets/*
|
||||
.idea
|
||||
|
||||
@@ -6,19 +6,14 @@ repos:
|
||||
entry: uv run ruff check
|
||||
language: system
|
||||
types: [python]
|
||||
files: ^lib/crewai/src/
|
||||
exclude: ^lib/crewai/
|
||||
- id: ruff-format
|
||||
name: ruff-format
|
||||
entry: uv run ruff format
|
||||
language: system
|
||||
types: [python]
|
||||
files: ^lib/crewai/src/
|
||||
exclude: ^lib/crewai/
|
||||
- id: mypy
|
||||
name: mypy
|
||||
entry: uv run mypy
|
||||
language: system
|
||||
types: [python]
|
||||
files: ^lib/crewai/src/
|
||||
exclude: ^lib/crewai/
|
||||
exclude: ^tests/
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 14 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 14 KiB |
@@ -5,82 +5,6 @@ icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
<Update label="Sep 20, 2025">
|
||||
## v0.193.2
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)
|
||||
|
||||
## What's Changed
|
||||
|
||||
- Updated pyproject templates to use the right version
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Sep 20, 2025">
|
||||
## v0.193.1
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)
|
||||
|
||||
## What's Changed
|
||||
|
||||
- Series of minor fixes and linter improvements
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Sep 19, 2025">
|
||||
## v0.193.0
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)
|
||||
|
||||
## Core Improvements & Fixes
|
||||
|
||||
- Fixed handling of the `model` parameter during OpenAI adapter initialization
|
||||
- Resolved test duration cache issues in CI workflows
|
||||
- Fixed flaky test related to repeated tool usage by agents
|
||||
- Added missing event exports to `__init__.py` for consistent module behavior
|
||||
- Dropped message storage from metadata in Mem0 to reduce bloat
|
||||
- Fixed L2 distance metric support for backward compatibility in vector search
|
||||
|
||||
## New Features & Enhancements
|
||||
|
||||
- Introduced thread-safe platform context management
|
||||
- Added test duration caching for optimized `pytest-split` runs
|
||||
- Added ephemeral trace improvements for better trace control
|
||||
- Made search parameters for RAG, knowledge, and memory fully configurable
|
||||
- Enabled ChromaDB to use OpenAI API for embedding functions
|
||||
- Added deeper observability tools for user-level insights
|
||||
- Unified RAG storage system with instance-specific client support
|
||||
|
||||
## Documentation & Guides
|
||||
|
||||
- Updated `RagTool` references to reflect CrewAI native RAG implementation
|
||||
- Improved internal docs for `langgraph` and `openai` agent adapters with type annotations and docstrings
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Sep 11, 2025">
|
||||
## v0.186.1
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)
|
||||
|
||||
## What's Changed
|
||||
|
||||
- Fixed version not being found and silently failing reversion
|
||||
- Bumped CrewAI version to 0.186.1 and updated dependencies in the CLI
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Sep 10, 2025">
|
||||
## v0.186.0
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)
|
||||
|
||||
## What's Changed
|
||||
|
||||
- Refer to the GitHub release notes for detailed changes
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Sep 04, 2025">
|
||||
## v0.177.0
|
||||
|
||||
|
||||
@@ -404,10 +404,6 @@ crewai config reset
|
||||
After resetting configuration, re-run `crewai login` to authenticate again.
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
CrewAI CLI handles authentication to the Tool Repository automatically when adding packages to your project. Just append `crewai` before any `uv` command to use it. E.g. `crewai uv add requests`. For more information, see [Tool Repository](https://docs.crewai.com/enterprise/features/tool-repository) docs.
|
||||
</Tip>
|
||||
|
||||
<Note>
|
||||
Configuration settings are stored in `~/.config/crewai/settings.json`. Some settings like organization name and UUID are read-only and managed through authentication and organization commands. Tool repository related settings are hidden and cannot be set directly by users.
|
||||
</Note>
|
||||
|
||||
@@ -52,36 +52,6 @@ researcher = Agent(
|
||||
)
|
||||
```
|
||||
|
||||
## Adding other packages after installing a tool
|
||||
|
||||
After installing a tool from the CrewAI Enterprise Tool Repository, you need to use the `crewai uv` command to add other packages to your project.
|
||||
Using pure `uv` commands will fail due to authentication to tool repository being handled by the CLI. By using the `crewai uv` command, you can add other packages to your project without having to worry about authentication.
|
||||
Any `uv` command can be used with the `crewai uv` command, making it a powerful tool for managing your project's dependencies without the hassle of managing authentication through environment variables or other methods.
|
||||
|
||||
Say that you have installed a custom tool from the CrewAI Enterprise Tool Repository called "my-tool":
|
||||
|
||||
```bash
|
||||
crewai tool install my-tool
|
||||
```
|
||||
|
||||
And now you want to add another package to your project, you can use the following command:
|
||||
|
||||
```bash
|
||||
crewai uv add requests
|
||||
```
|
||||
|
||||
Other commands like `uv sync` or `uv remove` can also be used with the `crewai uv` command:
|
||||
|
||||
```bash
|
||||
crewai uv sync
|
||||
```
|
||||
|
||||
```bash
|
||||
crewai uv remove requests
|
||||
```
|
||||
|
||||
This will add the package to your project and update `pyproject.toml` accordingly.
|
||||
|
||||
## Creating and Publishing Tools
|
||||
|
||||
To create a new tool project:
|
||||
|
||||
@@ -27,7 +27,7 @@ Follow the steps below to get Crewing! 🚣♂️
|
||||
<Step title="Navigate to your new crew project">
|
||||
<CodeGroup>
|
||||
```shell Terminal
|
||||
cd latest_ai_development
|
||||
cd latest-ai-development
|
||||
```
|
||||
</CodeGroup>
|
||||
</Step>
|
||||
|
||||
@@ -9,7 +9,7 @@ mode: "wide"
|
||||
|
||||
## Description
|
||||
|
||||
The `RagTool` is designed to answer questions by leveraging the power of Retrieval-Augmented Generation (RAG) through CrewAI's native RAG system.
|
||||
The `RagTool` is designed to answer questions by leveraging the power of Retrieval-Augmented Generation (RAG) through EmbedChain.
|
||||
It provides a dynamic knowledge base that can be queried to retrieve relevant information from various data sources.
|
||||
This tool is particularly useful for applications that require access to a vast array of information and need to provide contextually relevant answers.
|
||||
|
||||
@@ -76,8 +76,8 @@ The `RagTool` can be used with a wide variety of data sources, including:
|
||||
The `RagTool` accepts the following parameters:
|
||||
|
||||
- **summarize**: Optional. Whether to summarize the retrieved content. Default is `False`.
|
||||
- **adapter**: Optional. A custom adapter for the knowledge base. If not provided, a CrewAIRagAdapter will be used.
|
||||
- **config**: Optional. Configuration for the underlying CrewAI RAG system.
|
||||
- **adapter**: Optional. A custom adapter for the knowledge base. If not provided, an EmbedchainAdapter will be used.
|
||||
- **config**: Optional. Configuration for the underlying EmbedChain App.
|
||||
|
||||
## Adding Content
|
||||
|
||||
@@ -130,23 +130,44 @@ from crewai_tools import RagTool
|
||||
|
||||
# Create a RAG tool with custom configuration
|
||||
config = {
|
||||
"vectordb": {
|
||||
"provider": "qdrant",
|
||||
"app": {
|
||||
"name": "custom_app",
|
||||
},
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"collection_name": "my-collection"
|
||||
"model": "gpt-4",
|
||||
}
|
||||
},
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small"
|
||||
"model": "text-embedding-ada-002"
|
||||
}
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "elasticsearch",
|
||||
"config": {
|
||||
"collection_name": "my-collection",
|
||||
"cloud_id": "deployment-name:xxxx",
|
||||
"api_key": "your-key",
|
||||
"verify_certs": False
|
||||
}
|
||||
},
|
||||
"chunker": {
|
||||
"chunk_size": 400,
|
||||
"chunk_overlap": 100,
|
||||
"length_function": "len",
|
||||
"min_chunk_size": 0
|
||||
}
|
||||
}
|
||||
|
||||
rag_tool = RagTool(config=config, summarize=True)
|
||||
```
|
||||
|
||||
The internal RAG tool utilizes the Embedchain adapter, allowing you to pass any configuration options that are supported by Embedchain.
|
||||
You can refer to the [Embedchain documentation](https://docs.embedchain.ai/components/introduction) for details.
|
||||
Make sure to review the configuration options available in the .yaml file.
|
||||
|
||||
## Conclusion
|
||||
The `RagTool` provides a powerful way to create and query knowledge bases from various data sources. By leveraging Retrieval-Augmented Generation, it enables agents to access and retrieve relevant information efficiently, enhancing their ability to provide accurate and contextually appropriate responses.
|
||||
|
||||
@@ -5,82 +5,6 @@ icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
<Update label="2025년 9월 20일">
|
||||
## v0.193.2
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
- 올바른 버전을 사용하도록 pyproject 템플릿 업데이트
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2025년 9월 20일">
|
||||
## v0.193.1
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
- 일련의 사소한 수정 및 린터 개선
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2025년 9월 19일">
|
||||
## v0.193.0
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)
|
||||
|
||||
## 핵심 개선 사항 및 수정 사항
|
||||
|
||||
- OpenAI 어댑터 초기화 중 `model` 매개변수 처리 수정
|
||||
- CI 워크플로에서 테스트 소요 시간 캐시 문제 해결
|
||||
- 에이전트의 반복 도구 사용과 관련된 불안정한 테스트 수정
|
||||
- 일관된 모듈 동작을 위해 누락된 이벤트 내보내기를 `__init__.py`에 추가
|
||||
- 메타데이터 부하를 줄이기 위해 Mem0에서 메시지 저장 제거
|
||||
- 벡터 검색의 하위 호환성을 위해 L2 거리 메트릭 지원 수정
|
||||
|
||||
## 새로운 기능 및 향상 사항
|
||||
|
||||
- 스레드 안전한 플랫폼 컨텍스트 관리 도입
|
||||
- `pytest-split` 실행 최적화를 위한 테스트 소요 시간 캐싱 추가
|
||||
- 더 나은 추적 제어를 위한 일시적(trace) 개선
|
||||
- RAG, 지식, 메모리 검색 매개변수를 완전 구성 가능하게 변경
|
||||
- ChromaDB가 임베딩 함수에 OpenAI API를 사용할 수 있도록 지원
|
||||
- 사용자 수준 인사이트를 위한 심화된 관찰 가능성 도구 추가
|
||||
- 인스턴스별 클라이언트를 지원하는 통합 RAG 스토리지 시스템
|
||||
|
||||
## 문서 및 가이드
|
||||
|
||||
- CrewAI 네이티브 RAG 구현을 반영하도록 `RagTool` 참조 업데이트
|
||||
- 타입 주석과 도크스트링을 포함해 `langgraph` 및 `openai` 에이전트 어댑터 내부 문서 개선
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2025년 9월 11일">
|
||||
## v0.186.1
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
- 버전을 찾지 못해 조용히 되돌리는(reversion) 문제 수정
|
||||
- CLI에서 CrewAI 버전을 0.186.1로 올리고 의존성 업데이트
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2025년 9월 10일">
|
||||
## v0.186.0
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
- 자세한 변경 사항은 GitHub 릴리스 노트를 참조하세요
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2025년 9월 4일">
|
||||
## v0.177.0
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ mode: "wide"
|
||||
<Step title="새로운 crew 프로젝트로 이동하기">
|
||||
<CodeGroup>
|
||||
```shell Terminal
|
||||
cd latest_ai_development
|
||||
cd latest-ai-development
|
||||
```
|
||||
</CodeGroup>
|
||||
</Step>
|
||||
|
||||
@@ -5,82 +5,6 @@ icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
<Update label="20 set 2025">
|
||||
## v0.193.2
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)
|
||||
|
||||
## O que Mudou
|
||||
|
||||
- Atualizados templates do pyproject para usar a versão correta
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="20 set 2025">
|
||||
## v0.193.1
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)
|
||||
|
||||
## O que Mudou
|
||||
|
||||
- Série de pequenas correções e melhorias de linter
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="19 set 2025">
|
||||
## v0.193.0
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)
|
||||
|
||||
## Melhorias e Correções Principais
|
||||
|
||||
- Corrigido manuseio do parâmetro `model` durante a inicialização do adaptador OpenAI
|
||||
- Resolvidos problemas de cache da duração de testes nos fluxos de CI
|
||||
- Corrigido teste instável relacionado ao uso repetido de ferramentas pelos agentes
|
||||
- Adicionadas exportações de eventos ausentes no `__init__.py` para comportamento consistente do módulo
|
||||
- Removido armazenamento de mensagem dos metadados no Mem0 para reduzir inchaço
|
||||
- Corrigido suporte à métrica de distância L2 para compatibilidade retroativa na busca vetorial
|
||||
|
||||
## Novos Recursos e Melhorias
|
||||
|
||||
- Introduzida gestão de contexto de plataforma com segurança de threads
|
||||
- Adicionado cache da duração de testes para execuções otimizadas do `pytest-split`
|
||||
- Melhorias de traces efêmeros para melhor controle de rastreamento
|
||||
- Parâmetros de busca para RAG, conhecimento e memória totalmente configuráveis
|
||||
- Habilitado ChromaDB para usar a OpenAI API para funções de embedding
|
||||
- Adicionadas ferramentas de observabilidade mais profundas para insights ao nível do usuário
|
||||
- Sistema de armazenamento RAG unificado com suporte a cliente específico por instância
|
||||
|
||||
## Documentação e Guias
|
||||
|
||||
- Atualizadas referências do `RagTool` para refletir a implementação nativa de RAG do CrewAI
|
||||
- Melhorada documentação interna para adaptadores de agente `langgraph` e `openai` com anotações de tipo e docstrings
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="11 set 2025">
|
||||
## v0.186.1
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)
|
||||
|
||||
## O que Mudou
|
||||
|
||||
- Corrigida falha silenciosa de reversão quando a versão não era encontrada
|
||||
- Versão do CrewAI atualizada para 0.186.1 e dependências do CLI atualizadas
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="10 set 2025">
|
||||
## v0.186.0
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)
|
||||
|
||||
## O que Mudou
|
||||
|
||||
- Consulte as notas de lançamento no GitHub para detalhes completos
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="04 set 2025">
|
||||
## v0.177.0
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ Siga os passos abaixo para começar a tripular! 🚣♂️
|
||||
<Step title="Navegue até o novo projeto da sua tripulação">
|
||||
<CodeGroup>
|
||||
```shell Terminal
|
||||
cd latest_ai_development
|
||||
cd latest-ai-development
|
||||
```
|
||||
</CodeGroup>
|
||||
</Step>
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
3.13
|
||||
@@ -1,124 +0,0 @@
|
||||
[project]
|
||||
name = "crewai"
|
||||
dynamic = ["version"]
|
||||
description = ""
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Greyson Lalonde", email = "greyson.r.lalonde@gmail.com" }
|
||||
]
|
||||
keywords = [
|
||||
"crewai",
|
||||
"ai",
|
||||
"agents",
|
||||
"framework",
|
||||
"orchestration",
|
||||
"llm",
|
||||
"core",
|
||||
"typed",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
requires-python = ">=3.10, <3.14"
|
||||
dependencies = [
|
||||
# Core Dependencies
|
||||
"crewai",
|
||||
"pydantic>=2.11.9",
|
||||
"openai>=1.13.3",
|
||||
"litellm==1.74.9",
|
||||
"instructor>=1.3.3",
|
||||
# Text Processing
|
||||
"pdfplumber>=0.11.4",
|
||||
"regex>=2024.9.11",
|
||||
# Telemetry and Monitoring
|
||||
"opentelemetry-api>=1.30.0",
|
||||
"opentelemetry-sdk>=1.30.0",
|
||||
"opentelemetry-exporter-otlp-proto-http>=1.30.0",
|
||||
"tokenizers>=0.20.3",
|
||||
"openpyxl>=3.1.5",
|
||||
"pyvis>=0.3.2",
|
||||
# Authentication and Security
|
||||
"python-dotenv>=1.1.1",
|
||||
"pyjwt>=2.9.0",
|
||||
# Configuration and Utils
|
||||
"click>=8.1.7",
|
||||
"appdirs>=1.4.4",
|
||||
"jsonref>=1.1.0",
|
||||
"json-repair==0.25.2",
|
||||
"tomli-w>=1.1.0",
|
||||
"tomli>=2.0.2",
|
||||
"blinker>=1.9.0",
|
||||
"json5>=0.10.0",
|
||||
"portalocker==2.7.0",
|
||||
"chromadb~=1.1.0",
|
||||
"pydantic-settings>=2.10.1",
|
||||
"uv>=0.4.25",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
]
|
||||
pdfplumber = [
|
||||
"pdfplumber>=0.11.4",
|
||||
]
|
||||
pandas = [
|
||||
"pandas>=2.2.3",
|
||||
]
|
||||
openpyxl = [
|
||||
"openpyxl>=3.1.5",
|
||||
]
|
||||
mem0 = ["mem0ai>=0.1.94"]
|
||||
docling = [
|
||||
"docling>=2.12.0",
|
||||
]
|
||||
aisuite = [
|
||||
"aisuite>=0.1.10",
|
||||
]
|
||||
qdrant = [
|
||||
"qdrant-client[fastembed]>=1.14.3",
|
||||
]
|
||||
aws = [
|
||||
"boto3>=1.40.38",
|
||||
]
|
||||
watson = [
|
||||
"ibm-watsonx-ai>=1.3.39",
|
||||
]
|
||||
voyageai = [
|
||||
"voyageai>=0.3.5",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
crewai = "crewai.cli.cli:crewai"
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://crewai.com"
|
||||
Documentation = "https://docs.crewai.com"
|
||||
Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "strict"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.version]
|
||||
path = "src/crewai/__init__.py"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/crewai"]
|
||||
@@ -1,12 +0,0 @@
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.parser import AgentAction, AgentFinish, OutputParserError, parse
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
|
||||
__all__ = [
|
||||
"AgentAction",
|
||||
"AgentFinish",
|
||||
"CacheHandler",
|
||||
"OutputParserError",
|
||||
"ToolsHandler",
|
||||
"parse",
|
||||
]
|
||||
@@ -1,166 +0,0 @@
|
||||
"""CrewAI events system for monitoring and extending agent behavior.
|
||||
|
||||
This module provides the event infrastructure that allows users to:
|
||||
- Monitor agent, task, and crew execution
|
||||
- Track memory operations and performance
|
||||
- Build custom logging and analytics
|
||||
- Extend CrewAI with custom event handlers
|
||||
"""
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentEvaluationCompletedEvent,
|
||||
AgentEvaluationFailedEvent,
|
||||
AgentEvaluationStartedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestResultEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowEvent,
|
||||
FlowFinishedEvent,
|
||||
FlowPlotEvent,
|
||||
FlowStartedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.logging_events import (
|
||||
AgentLogsExecutionEvent,
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
ReasoningEvent,
|
||||
)
|
||||
from crewai.events.types.task_events import (
|
||||
TaskCompletedEvent,
|
||||
TaskEvaluationEvent,
|
||||
TaskFailedEvent,
|
||||
TaskStartedEvent,
|
||||
)
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolExecutionErrorEvent,
|
||||
ToolSelectionErrorEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
ToolValidateInputErrorEvent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentEvaluationCompletedEvent",
|
||||
"AgentEvaluationFailedEvent",
|
||||
"AgentEvaluationStartedEvent",
|
||||
"AgentExecutionCompletedEvent",
|
||||
"AgentExecutionErrorEvent",
|
||||
"AgentExecutionStartedEvent",
|
||||
"AgentLogsExecutionEvent",
|
||||
"AgentLogsStartedEvent",
|
||||
"AgentReasoningCompletedEvent",
|
||||
"AgentReasoningFailedEvent",
|
||||
"AgentReasoningStartedEvent",
|
||||
"BaseEventListener",
|
||||
"CrewKickoffCompletedEvent",
|
||||
"CrewKickoffFailedEvent",
|
||||
"CrewKickoffStartedEvent",
|
||||
"CrewTestCompletedEvent",
|
||||
"CrewTestFailedEvent",
|
||||
"CrewTestResultEvent",
|
||||
"CrewTestStartedEvent",
|
||||
"CrewTrainCompletedEvent",
|
||||
"CrewTrainFailedEvent",
|
||||
"CrewTrainStartedEvent",
|
||||
"FlowCreatedEvent",
|
||||
"FlowEvent",
|
||||
"FlowFinishedEvent",
|
||||
"FlowPlotEvent",
|
||||
"FlowStartedEvent",
|
||||
"KnowledgeQueryCompletedEvent",
|
||||
"KnowledgeQueryFailedEvent",
|
||||
"KnowledgeQueryStartedEvent",
|
||||
"KnowledgeRetrievalCompletedEvent",
|
||||
"KnowledgeRetrievalStartedEvent",
|
||||
"KnowledgeSearchQueryFailedEvent",
|
||||
"LLMCallCompletedEvent",
|
||||
"LLMCallFailedEvent",
|
||||
"LLMCallStartedEvent",
|
||||
"LLMGuardrailCompletedEvent",
|
||||
"LLMGuardrailStartedEvent",
|
||||
"LLMStreamChunkEvent",
|
||||
"LiteAgentExecutionCompletedEvent",
|
||||
"LiteAgentExecutionErrorEvent",
|
||||
"LiteAgentExecutionStartedEvent",
|
||||
"MemoryQueryCompletedEvent",
|
||||
"MemoryQueryFailedEvent",
|
||||
"MemoryQueryStartedEvent",
|
||||
"MemoryRetrievalCompletedEvent",
|
||||
"MemoryRetrievalStartedEvent",
|
||||
"MemorySaveCompletedEvent",
|
||||
"MemorySaveFailedEvent",
|
||||
"MemorySaveStartedEvent",
|
||||
"MethodExecutionFailedEvent",
|
||||
"MethodExecutionFinishedEvent",
|
||||
"MethodExecutionStartedEvent",
|
||||
"ReasoningEvent",
|
||||
"TaskCompletedEvent",
|
||||
"TaskEvaluationEvent",
|
||||
"TaskFailedEvent",
|
||||
"TaskStartedEvent",
|
||||
"ToolExecutionErrorEvent",
|
||||
"ToolSelectionErrorEvent",
|
||||
"ToolUsageErrorEvent",
|
||||
"ToolUsageEvent",
|
||||
"ToolUsageFinishedEvent",
|
||||
"ToolUsageStartedEvent",
|
||||
"ToolValidateInputErrorEvent",
|
||||
"crewai_event_bus",
|
||||
]
|
||||
@@ -1,229 +0,0 @@
|
||||
import logging
|
||||
import uuid
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
mark_first_execution_completed,
|
||||
prompt_user_for_trace_viewing,
|
||||
should_auto_collect_first_time_traces,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _update_or_create_env_file():
|
||||
"""Update or create .env file with CREWAI_TRACING_ENABLED=true."""
|
||||
env_path = Path(".env")
|
||||
env_content = ""
|
||||
variable_name = "CREWAI_TRACING_ENABLED"
|
||||
variable_value = "true"
|
||||
|
||||
# Read existing content if file exists
|
||||
if env_path.exists():
|
||||
with open(env_path, "r") as f:
|
||||
env_content = f.read()
|
||||
|
||||
# Check if CREWAI_TRACING_ENABLED is already set
|
||||
lines = env_content.splitlines()
|
||||
variable_exists = False
|
||||
updated_lines = []
|
||||
|
||||
for line in lines:
|
||||
if line.strip().startswith(f"{variable_name}="):
|
||||
# Update existing variable
|
||||
updated_lines.append(f"{variable_name}={variable_value}")
|
||||
variable_exists = True
|
||||
else:
|
||||
updated_lines.append(line)
|
||||
|
||||
# Add variable if it doesn't exist
|
||||
if not variable_exists:
|
||||
if updated_lines and not updated_lines[-1].strip():
|
||||
# If last line is empty, replace it
|
||||
updated_lines[-1] = f"{variable_name}={variable_value}"
|
||||
else:
|
||||
# Add new line and then the variable
|
||||
updated_lines.append(f"{variable_name}={variable_value}")
|
||||
|
||||
# Write updated content
|
||||
with open(env_path, "w") as f:
|
||||
f.write("\n".join(updated_lines))
|
||||
if updated_lines: # Add final newline if there's content
|
||||
f.write("\n")
|
||||
|
||||
|
||||
class FirstTimeTraceHandler:
    """Coordinates trace collection and display for a user's first execution.

    The handler is driven externally: it is told when events were collected
    and when execution finished, then prompts the user, pushes buffered
    events to the backend, and shows the resulting trace link.
    """

    def __init__(self) -> None:
        # Flags describing where we are in the first-run trace flow.
        self.is_first_time: bool = False
        self.collected_events: bool = False
        # Populated lazily once a backend batch exists.
        self.trace_batch_id: str | None = None
        self.ephemeral_url: str | None = None
        self.batch_manager: TraceBatchManager | None = None

    def initialize_for_first_time_user(self) -> bool:
        """Detect whether this is the user's first run and enable collection."""
        self.is_first_time = should_auto_collect_first_time_traces()
        return self.is_first_time

    def set_batch_manager(self, batch_manager: TraceBatchManager):
        """Attach the batch manager used to buffer and ship trace events."""
        self.batch_manager = batch_manager

    def mark_events_collected(self):
        """Record that at least one event was captured during execution."""
        self.collected_events = True

    def handle_execution_completion(self):
        """Run the post-execution prompt / upload / display flow."""
        # Nothing to do unless this is a first run that actually produced events.
        if not self.is_first_time or not self.collected_events:
            return

        try:
            if prompt_user_for_trace_viewing(timeout_seconds=20):
                self._initialize_backend_and_send_events()

                # Persist the opt-in so future runs trace automatically;
                # failure to write .env must never break the flow.
                try:
                    _update_or_create_env_file()
                except Exception:  # noqa: S110
                    pass

                if self.ephemeral_url:
                    self._display_ephemeral_trace_link()

            mark_first_execution_completed()

        except Exception as e:
            self._gracefully_fail(f"Error in trace handling: {e}")
            mark_first_execution_completed()

    def _initialize_backend_and_send_events(self):
        """Create the ephemeral backend batch, flush buffered events, finalize."""
        manager = self.batch_manager
        if not manager:
            return

        try:
            if not manager.backend_initialized:
                batch = manager.current_batch
                original_metadata = batch.execution_metadata if batch else {}

                user_context = {
                    "privacy_level": "standard",
                    "user_id": "first_time_user",
                    "session_id": str(uuid.uuid4()),
                    "trace_id": manager.trace_batch_id,
                }

                execution_metadata = {
                    "execution_type": original_metadata.get("execution_type", "crew"),
                    "crew_name": original_metadata.get(
                        "crew_name", "First Time Execution"
                    ),
                    "flow_name": original_metadata.get("flow_name"),
                    "agent_count": original_metadata.get("agent_count", 1),
                    "task_count": original_metadata.get("task_count", 1),
                    "crewai_version": original_metadata.get("crewai_version"),
                }

                manager._initialize_backend_batch(
                    user_context=user_context,
                    execution_metadata=execution_metadata,
                    use_ephemeral=True,
                )
                manager.backend_initialized = True

            if manager.event_buffer:
                manager._send_events_to_backend()

            manager.finalize_batch()
            self.ephemeral_url = manager.ephemeral_trace_url

            # No URL means upload failed — tell the user what we kept locally.
            if not self.ephemeral_url:
                self._show_local_trace_message()

        except Exception as e:
            self._gracefully_fail(f"Backend initialization failed: {e}")

    def _display_ephemeral_trace_link(self):
        """Display the ephemeral trace link to the user and automatically open browser."""
        console = Console()

        # Best-effort: a headless environment has no browser to open.
        try:
            webbrowser.open(self.ephemeral_url)
        except Exception:  # noqa: S110
            pass

        panel_content = f"""
🎉 Your First CrewAI Execution Trace is Ready!

View your execution details here:
{self.ephemeral_url}

This trace shows:
• Agent decisions and interactions
• Task execution timeline
• Tool usage and results
• LLM calls and responses

✅ Tracing has been enabled for future runs! (CREWAI_TRACING_ENABLED=true added to .env)
You can also add tracing=True to your Crew(tracing=True) / Flow(tracing=True) for more control.

📝 Note: This link will expire in 24 hours.
""".strip()

        panel = Panel(
            panel_content,
            title="🔍 Execution Trace Generated",
            border_style="bright_green",
            padding=(1, 2),
        )

        console.print("\n")
        console.print(panel)
        console.print()

    def _gracefully_fail(self, error_message: str):
        """Surface an error as a gentle note without disrupting the user."""
        Console().print(f"[yellow]Note: {error_message}[/yellow]")
        logger.debug(f"First-time trace error: {error_message}")

    def _show_local_trace_message(self):
        """Show message when traces were collected locally but couldn't be uploaded."""
        console = Console()

        panel_content = f"""
📊 Your execution traces were collected locally!

Unfortunately, we couldn't upload them to the server right now, but here's what we captured:
• {len(self.batch_manager.event_buffer)} trace events
• Execution duration: {self.batch_manager.calculate_duration("execution")}ms
• Batch ID: {self.batch_manager.trace_batch_id}

Tracing has been enabled for future runs! (CREWAI_TRACING_ENABLED=true added to .env)
The traces include agent decisions, task execution, and tool usage.
""".strip()

        panel = Panel(
            panel_content,
            title="🔍 Local Traces Collected",
            border_style="yellow",
            padding=(1, 2),
        )

        console.print("\n")
        console.print(panel)
        console.print()
|
||||
@@ -1,379 +0,0 @@
|
||||
import getpass
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.serialization import to_serializable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_tracing_enabled() -> bool:
    """Return True when CREWAI_TRACING_ENABLED is set to "true" (case-insensitive)."""
    flag = os.getenv("CREWAI_TRACING_ENABLED", "false")
    return flag.lower() == "true"
|
||||
|
||||
|
||||
def on_first_execution_tracing_confirmation() -> bool:
    """Ask the user — once, on their very first run — whether to enable tracing.

    Returns False in test environments and on every run after the first.
    """
    if _is_test_environment():
        return False

    if not is_first_execution():
        return False

    # Record the first run before prompting so we never ask twice.
    mark_first_execution_done()
    return click.confirm(
        "This is the first execution of CrewAI. Do you want to enable tracing?",
        default=True,
        show_default=True,
    )
|
||||
|
||||
|
||||
def _is_test_environment() -> bool:
|
||||
"""Detect if we're running in a test environment."""
|
||||
return os.environ.get("CREWAI_TESTING", "").lower() == "true"
|
||||
|
||||
|
||||
def _get_machine_id() -> str:
|
||||
"""Stable, privacy-preserving machine fingerprint (cross-platform)."""
|
||||
parts = []
|
||||
|
||||
try:
|
||||
mac = ":".join(
|
||||
[f"{(uuid.getnode() >> b) & 0xFF:02x}" for b in range(0, 12, 2)][::-1]
|
||||
)
|
||||
parts.append(mac)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
try:
|
||||
sysname = platform.system()
|
||||
parts.append(sysname)
|
||||
except Exception:
|
||||
sysname = "unknown"
|
||||
parts.append(sysname)
|
||||
|
||||
try:
|
||||
if sysname == "Darwin":
|
||||
try:
|
||||
res = subprocess.run(
|
||||
["/usr/sbin/system_profiler", "SPHardwareDataType"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=2,
|
||||
)
|
||||
m = re.search(r"Hardware UUID:\s*([A-Fa-f0-9\-]+)", res.stdout)
|
||||
if m:
|
||||
parts.append(m.group(1))
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
elif sysname == "Linux":
|
||||
linux_id = _get_linux_machine_id()
|
||||
if linux_id:
|
||||
parts.append(linux_id)
|
||||
|
||||
elif sysname == "Windows":
|
||||
try:
|
||||
res = subprocess.run(
|
||||
[
|
||||
"C:\\Windows\\System32\\wbem\\wmic.exe",
|
||||
"csproduct",
|
||||
"get",
|
||||
"UUID",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=2,
|
||||
)
|
||||
lines = [
|
||||
line.strip() for line in res.stdout.splitlines() if line.strip()
|
||||
]
|
||||
if len(lines) >= 2:
|
||||
parts.append(lines[1])
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
else:
|
||||
generic_id = _get_generic_system_id()
|
||||
if generic_id:
|
||||
parts.append(generic_id)
|
||||
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
if len(parts) <= 1:
|
||||
try:
|
||||
import socket
|
||||
|
||||
parts.append(socket.gethostname())
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
try:
|
||||
parts.append(getpass.getuser())
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
try:
|
||||
parts.append(platform.machine())
|
||||
parts.append(platform.processor())
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
if not parts:
|
||||
parts.append("unknown-system")
|
||||
parts.append(str(uuid.uuid4()))
|
||||
|
||||
return hashlib.sha256("".join(parts).encode()).hexdigest()
|
||||
|
||||
|
||||
def _get_linux_machine_id() -> str | None:
|
||||
linux_id_sources = [
|
||||
"/etc/machine-id",
|
||||
"/sys/class/dmi/id/product_uuid",
|
||||
"/proc/sys/kernel/random/boot_id",
|
||||
"/sys/class/dmi/id/board_serial",
|
||||
"/sys/class/dmi/id/chassis_serial",
|
||||
]
|
||||
|
||||
for source in linux_id_sources:
|
||||
try:
|
||||
path = Path(source)
|
||||
if path.exists() and path.is_file():
|
||||
content = path.read_text().strip()
|
||||
if content and content.lower() not in [
|
||||
"unknown",
|
||||
"to be filled by o.e.m.",
|
||||
"",
|
||||
]:
|
||||
return content
|
||||
except Exception: # noqa: S112, PERF203
|
||||
continue
|
||||
|
||||
try:
|
||||
import socket
|
||||
|
||||
hostname = socket.gethostname()
|
||||
arch = platform.machine()
|
||||
if hostname and arch:
|
||||
return f"{hostname}-{arch}"
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_generic_system_id() -> str | None:
|
||||
try:
|
||||
parts = []
|
||||
|
||||
try:
|
||||
import socket
|
||||
|
||||
hostname = socket.gethostname()
|
||||
if hostname:
|
||||
parts.append(hostname)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
try:
|
||||
parts.append(platform.machine())
|
||||
parts.append(platform.processor())
|
||||
parts.append(platform.architecture()[0])
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
try:
|
||||
container_id = os.environ.get(
|
||||
"HOSTNAME", os.environ.get("CONTAINER_ID", "")
|
||||
)
|
||||
if container_id:
|
||||
parts.append(container_id)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
if parts:
|
||||
return "-".join(filter(None, parts))
|
||||
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _user_data_file() -> Path:
    """Path of the JSON file holding anonymous per-user state.

    The storage directory is created on demand.
    """
    storage_dir = Path(db_storage_path())
    storage_dir.mkdir(parents=True, exist_ok=True)
    return storage_dir / ".crewai_user.json"
|
||||
|
||||
|
||||
def _load_user_data() -> dict:
    """Read the persisted user-data JSON, returning {} on any read/parse failure."""
    data_file = _user_data_file()
    if not data_file.exists():
        return {}
    try:
        return json.loads(data_file.read_text())
    except (json.JSONDecodeError, OSError, PermissionError) as e:
        # Corrupt or unreadable state is treated as "no state" rather than fatal.
        logger.warning(f"Failed to load user data: {e}")
        return {}
|
||||
|
||||
|
||||
def _save_user_data(data: dict) -> None:
    """Persist user-data JSON; write failures are logged, never raised."""
    try:
        _user_data_file().write_text(json.dumps(data, indent=2))
    except (OSError, PermissionError) as e:
        logger.warning(f"Failed to save user data: {e}")
|
||||
|
||||
|
||||
def get_user_id() -> str:
    """Stable, anonymized user identifier with caching.

    The ID is a SHA-256 of username + machine fingerprint, computed once
    and then cached in the user-data file.
    """
    data = _load_user_data()

    # Return whatever was previously persisted, verbatim.
    if "user_id" in data:
        return data["user_id"]

    try:
        username = getpass.getuser()
    except Exception:
        username = "unknown"

    digest_seed = f"{username}|{_get_machine_id()}"
    user_id = hashlib.sha256(digest_seed.encode()).hexdigest()

    data["user_id"] = user_id
    _save_user_data(data)
    return user_id
|
||||
|
||||
|
||||
def is_first_execution() -> bool:
    """True if this user has never completed an execution before."""
    return not _load_user_data().get("first_execution_done", False)
|
||||
|
||||
|
||||
def mark_first_execution_done() -> None:
    """Record completion of the first run (idempotent).

    Also snapshots the anonymized user and machine identifiers alongside
    the completion timestamp.
    """
    data = _load_user_data()
    if data.get("first_execution_done", False):
        # Already recorded — avoid rewriting identifiers on every run.
        return

    data["first_execution_done"] = True
    data["first_execution_at"] = datetime.now().timestamp()
    data["user_id"] = get_user_id()
    data["machine_id"] = _get_machine_id()
    _save_user_data(data)
|
||||
|
||||
|
||||
def safe_serialize_to_dict(obj, exclude: set[str] | None = None) -> dict[str, Any]:
|
||||
"""Safely serialize an object to a dictionary for event data."""
|
||||
try:
|
||||
serialized = to_serializable(obj, exclude)
|
||||
if isinstance(serialized, dict):
|
||||
return serialized
|
||||
return {"serialized_data": serialized}
|
||||
except Exception as e:
|
||||
return {"serialization_error": str(e), "object_type": type(obj).__name__}
|
||||
|
||||
|
||||
def truncate_messages(messages, max_content_length=500, max_messages=5):
    """Truncate message content and limit the number of messages.

    Returns a new list; message dicts whose content must be shortened are
    shallow-copied first, so the caller's original messages are never
    mutated (the previous version edited them in place). Non-list inputs
    are returned unchanged, and non-string content is left as-is.

    Args:
        messages: List of message dicts (each may carry a "content" key).
        max_content_length: Maximum characters kept per message content.
        max_messages: Maximum number of messages kept.

    Returns:
        The (possibly shortened) list of messages, or the input unchanged
        when it is not a non-empty list.
    """
    if not messages or not isinstance(messages, list):
        return messages

    truncated = []
    for msg in messages[:max_messages]:
        # Guard type of content: multimodal messages may carry non-string payloads.
        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
            content = msg["content"]
            if len(content) > max_content_length:
                # Copy before editing so shared message objects stay intact.
                msg = {**msg, "content": content[:max_content_length] + "..."}
        truncated.append(msg)
    return truncated
|
||||
|
||||
|
||||
def should_auto_collect_first_time_traces() -> bool:
    """True if traces should be auto-collected for a first-time user (never in tests)."""
    return not _is_test_environment() and is_first_execution()
|
||||
|
||||
|
||||
def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
    """
    Prompt user if they want to see their traces with timeout.

    Returns True if user wants to see traces, False otherwise.
    """
    if _is_test_environment():
        return False

    try:
        import threading

        console = Console()

        content = Text()
        content.append("🔍 ", style="cyan bold")
        content.append(
            "Detailed execution traces are available!\n\n", style="cyan bold"
        )
        content.append("View insights including:\n", style="white")
        content.append(" • Agent decision-making process\n", style="bright_blue")
        content.append(" • Task execution flow and timing\n", style="bright_blue")
        content.append(" • Tool usage details", style="bright_blue")

        panel = Panel(
            content,
            title="[bold cyan]Execution Traces[/bold cyan]",
            border_style="cyan",
            padding=(1, 2),
        )
        console.print("\n")
        console.print(panel)

        prompt_text = click.style(
            f"Would you like to view your execution traces? [y/N] ({timeout_seconds}s timeout): ",
            fg="white",
            bold=True,
        )
        click.echo(prompt_text, nl=False)

        # Read stdin on a daemon thread so we can enforce the timeout.
        answer = [False]

        def read_answer():
            try:
                answer[0] = input().strip().lower() in ("y", "yes")
            except (EOFError, KeyboardInterrupt):
                answer[0] = False

        reader = threading.Thread(target=read_answer, daemon=True)
        reader.start()
        reader.join(timeout=timeout_seconds)

        # Still waiting for input after the timeout → default to "no".
        if reader.is_alive():
            return False

        return answer[0]

    except Exception:
        return False
|
||||
|
||||
|
||||
def mark_first_execution_completed() -> None:
    """Record the first execution as finished (invoked after the trace prompt).

    Thin alias around mark_first_execution_done() kept for call-site clarity.
    """
    mark_first_execution_done()
|
||||
@@ -1,7 +0,0 @@
|
||||
from crewai.experimental.evaluation.experiment.result import (
|
||||
ExperimentResult,
|
||||
ExperimentResults,
|
||||
)
|
||||
from crewai.experimental.evaluation.experiment.runner import ExperimentRunner
|
||||
|
||||
__all__ = ["ExperimentResult", "ExperimentResults", "ExperimentRunner"]
|
||||
@@ -1,4 +0,0 @@
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.flow.persistence import persist
|
||||
|
||||
__all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"]
|
||||
@@ -1,129 +0,0 @@
|
||||
import logging
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import Any, cast
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
class KnowledgeStorage(BaseKnowledgeStorage):
    """
    Extends Storage to handle embeddings for memory entries, improving
    search efficiency.
    """

    def __init__(
        self,
        embedder: ProviderSpec
        | BaseEmbeddingsProvider
        | type[BaseEmbeddingsProvider]
        | None = None,
        collection_name: str | None = None,
    ) -> None:
        """Create a knowledge store, optionally with a dedicated embedder/client.

        Args:
            embedder: Embedding provider spec/instance/class. When given, a
                dedicated ChromaDB client is built around it; otherwise the
                global RAG client is used lazily.
            collection_name: Optional suffix for the backing collection name.
        """
        self.collection_name = collection_name
        self._client: BaseClient | None = None

        # chromadb triggers a pydantic 'model_fields' deprecation warning; silence it.
        warnings.filterwarnings(
            "ignore",
            message=r".*'model_fields'.*is deprecated.*",
            module=r"^chromadb(\.|$)",
        )

        if embedder:
            embedding_function = build_embedder(embedder)  # type: ignore[arg-type]
            config = ChromaDBConfig(
                embedding_function=cast(
                    ChromaEmbeddingFunctionWrapper, embedding_function
                )
            )
            self._client = create_client(config)

    def _get_client(self) -> BaseClient:
        """Get the appropriate client - instance-specific or global."""
        return self._client if self._client else get_rag_client()

    def _full_collection_name(self) -> str:
        """Backing collection name: 'knowledge' plus the optional suffix.

        Centralized here so search/reset/save cannot drift apart (the
        expression was previously duplicated in all three methods).
        """
        return (
            f"knowledge_{self.collection_name}"
            if self.collection_name
            else "knowledge"
        )

    def search(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[SearchResult]:
        """Search the knowledge collection.

        Multiple query strings are joined with spaces into one query. Any
        failure (including an empty query) is logged and yields [].
        """
        try:
            if not query:
                raise ValueError("Query cannot be empty")

            client = self._get_client()
            query_text = " ".join(query) if len(query) > 1 else query[0]

            return client.search(
                collection_name=self._full_collection_name(),
                query=query_text,
                limit=limit,
                metadata_filter=metadata_filter,
                score_threshold=score_threshold,
            )
        except Exception as e:
            logging.error(
                f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
            )
            return []

    def reset(self) -> None:
        """Delete the backing collection; failures are logged, not raised."""
        try:
            client = self._get_client()
            client.delete_collection(collection_name=self._full_collection_name())
        except Exception as e:
            logging.error(
                f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
            )

    def save(self, documents: list[str]) -> None:
        """Upsert documents into the knowledge collection.

        Raises:
            ValueError: On an embedding dimension mismatch (mixed embedding
                models against the same collection).
            Exception: Re-raises any other upsert failure after logging it.
        """
        try:
            client = self._get_client()
            collection_name = self._full_collection_name()
            client.get_or_create_collection(collection_name=collection_name)

            rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]

            client.add_documents(
                collection_name=collection_name, documents=rag_documents
            )
        except Exception as e:
            if "dimension mismatch" in str(e).lower():
                Logger(verbose=True).log(
                    "error",
                    "Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
                    "red",
                )
                raise ValueError(
                    "Embedding dimension mismatch. Make sure you're using the same embedding model "
                    "across all operations with this collection."
                    "Try resetting the collection using `crewai reset-memories -a`"
                ) from e
            Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
            raise
|
||||
@@ -1,204 +0,0 @@
|
||||
import logging
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import Any, cast
|
||||
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.rag.types import BaseRecord
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
self.agents = agents
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
|
||||
self.type = type
|
||||
self._client: BaseClient | None = None
|
||||
|
||||
self.allow_reset = allow_reset
|
||||
self.path = path
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r".*'model_fields'.*is deprecated.*",
|
||||
module=r"^chromadb(\.|$)",
|
||||
)
|
||||
|
||||
if self.embedder_config:
|
||||
embedding_function = build_embedder(self.embedder_config)
|
||||
|
||||
try:
|
||||
_ = embedding_function(["test"])
|
||||
except Exception as e:
|
||||
provider = (
|
||||
self.embedder_config["provider"]
|
||||
if isinstance(self.embedder_config, dict)
|
||||
else self.embedder_config.__class__.__name__.replace(
|
||||
"Provider", ""
|
||||
).lower()
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to initialize embedder. Please check your configuration or connection.\n"
|
||||
f"Provider: {provider}\n"
|
||||
f"Error: {e}"
|
||||
) from e
|
||||
|
||||
batch_size = None
|
||||
if (
|
||||
isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
),
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
"""Get the appropriate client - instance-specific or global."""
|
||||
return self._client if self._client else get_rag_client()
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory names.
|
||||
"""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def _build_storage_file_name(self, type: str, file_name: str) -> str:
|
||||
"""
|
||||
Ensures file name does not exceed max allowed by OS
|
||||
"""
|
||||
base_path = f"{db_storage_path()}/{type}"
|
||||
|
||||
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
||||
logging.warning(
|
||||
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
|
||||
)
|
||||
file_name = file_name[:MAX_FILE_NAME_LENGTH]
|
||||
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
client.get_or_create_collection(collection_name=collection_name)
|
||||
|
||||
document: BaseRecord = {"content": value}
|
||||
if metadata:
|
||||
document["metadata"] = metadata
|
||||
|
||||
batch_size = None
|
||||
if (
|
||||
self.embedder_config
|
||||
and isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
client.add_documents(
|
||||
collection_name=collection_name,
|
||||
documents=[document],
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
client.add_documents(
|
||||
collection_name=collection_name, documents=[document]
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
return client.search(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
limit=limit,
|
||||
metadata_filter=filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
def reset(self) -> None:
    """Delete this memory's collection.

    Raises:
        Exception: When deletion fails for a reason other than the collection
            already being gone or the database being read-only.
    """
    try:
        client = self._get_client()
        if self.agents:
            collection_name = f"memory_{self.type}_{self.agents}"
        else:
            collection_name = f"memory_{self.type}"
        client.delete_collection(collection_name=collection_name)
    except Exception as e:
        message = str(e)
        # A readonly database or a missing collection means there is nothing
        # left to reset, so both are treated as success.
        already_reset = (
            "attempt to write a readonly database" in message
            or "does not exist" in message
        )
        if not already_reset:
            raise Exception(
                f"An error occurred while resetting the {self.type} memory: {e}"
            ) from e
|
||||
@@ -1 +0,0 @@
|
||||
"""Optional imports for RAG configuration providers."""
|
||||
@@ -1,149 +0,0 @@
|
||||
"""Base embeddings callable utilities for RAG systems."""
|
||||
|
||||
from typing import Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from crewai.rag.core.types import (
|
||||
Embeddable,
|
||||
Embedding,
|
||||
Embeddings,
|
||||
PyEmbedding,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
D = TypeVar("D", bound=Embeddable, contravariant=True)
|
||||
|
||||
|
||||
def normalize_embeddings(
    target: Embedding | list[Embedding] | PyEmbedding | list[PyEmbedding],
) -> Embeddings | None:
    """Normalize various embedding formats to a standard list of numpy arrays.

    Args:
        target: Input embeddings in various formats (list of floats, list of lists,
            numpy array, or list of numpy arrays).

    Returns:
        Normalized embeddings as a list of float32 numpy arrays.

    Raises:
        ValueError: If embeddings are empty or in an unsupported format.
    """
    if isinstance(target, np.ndarray):
        if target.ndim == 1:
            return [target.astype(np.float32)]
        if target.ndim == 2:
            return [row.astype(np.float32) for row in target]
        raise ValueError(f"Unsupported numpy array shape: {target.shape}")

    # The docstring promises ValueError for empty input; without this guard
    # the target[0] lookup below would raise IndexError instead.
    if len(target) == 0:
        raise ValueError("Embeddings cannot be empty")

    first = target[0]
    # bool is a subclass of int, so it must be excluded explicitly.
    if isinstance(first, (int, float)) and not isinstance(first, bool):
        return [np.array(target, dtype=np.float32)]
    if isinstance(first, list):
        return [np.array(emb, dtype=np.float32) for emb in target]
    if isinstance(first, np.ndarray):
        return [emb.astype(np.float32) for emb in target]  # type: ignore[union-attr]

    raise ValueError(f"Unsupported embeddings format: {type(first)}")
|
||||
|
||||
|
||||
def maybe_cast_one_to_many(target: T | list[T] | None) -> list[T] | None:
    """Wrap a single item in a list; pass lists and None through unchanged.

    Args:
        target: A single item or list of items.

    Returns:
        A list of items, or None if input is None.
    """
    if target is None:
        return None
    if isinstance(target, list):
        return target
    return [target]
|
||||
|
||||
|
||||
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Validate embeddings format and content.

    Args:
        embeddings: List of numpy arrays to validate.

    Returns:
        The same embeddings, unchanged, once every check passes.

    Raises:
        ValueError: If embeddings format or content is invalid.
    """
    if not isinstance(embeddings, list):
        raise ValueError(
            f"Expected embeddings to be a list, got {type(embeddings).__name__}"
        )
    if len(embeddings) == 0:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings"
        )
    if any(not isinstance(item, np.ndarray) for item in embeddings):
        raise ValueError(
            "Expected each embedding in the embeddings to be a numpy array"
        )

    def _is_numeric(value: object) -> bool:
        # Booleans pass int checks in Python, so they are rejected explicitly.
        return isinstance(
            value, (np.integer, float, np.floating)
        ) and not isinstance(value, bool)

    for i, embedding in enumerate(embeddings):
        if embedding.ndim == 0:
            raise ValueError(
                f"Expected a 1-dimensional array, got a 0-dimensional array {embedding}"
            )
        if embedding.size == 0:
            raise ValueError(
                f"Expected each embedding to be a 1-dimensional numpy array with at least 1 value. "
                f"Got an array with no values at position {i}"
            )
        if not all(_is_numeric(value) for value in embedding):
            raise ValueError(
                f"Expected embedding to contain numeric values, got non-numeric values at position {i}"
            )
    return embeddings
|
||||
|
||||
|
||||
@runtime_checkable
class EmbeddingFunction(Protocol[D]):
    """Protocol for embedding functions.

    Embedding functions convert input data (documents or images) into vector embeddings.
    Concrete subclasses get their ``__call__`` transparently wrapped (see
    ``__init_subclass__``) so returned embeddings are always normalized to a
    list of numpy arrays and validated before reaching the caller.
    """

    def __call__(self, input: D) -> Embeddings:
        """Convert input data to embeddings.

        Args:
            input: Input data to embed (documents or images).

        Returns:
            List of numpy arrays representing the embeddings.
        """
        ...

    def __init_subclass__(cls) -> None:
        """Wrap __call__ method to normalize and validate embeddings."""
        super().__init_subclass__()
        # Capture the subclass's own implementation before rebinding the slot.
        original_call = cls.__call__

        def wrapped_call(self: EmbeddingFunction[D], input: D) -> Embeddings:
            # Run the user implementation, then coerce its output into the
            # canonical list-of-numpy-arrays form and validate every entry.
            result = original_call(self, input)
            if result is None:
                raise ValueError("Embedding function returned None")
            normalized = normalize_embeddings(result)
            if normalized is None:
                raise ValueError("Normalization returned None for non-None input")
            return validate_embeddings(normalized)

        # Rebinding happens once per subclass, at class-creation time.
        cls.__call__ = wrapped_call  # type: ignore[method-assign]

    def embed_query(self, input: D) -> Embeddings:
        """
        Get the embeddings for a query input.
        This method is optional, and if not implemented, the default behavior is to call __call__.
        """
        return self.__call__(input=input)
|
||||
@@ -1,23 +0,0 @@
|
||||
"""Base class for embedding providers."""
|
||||
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
|
||||
T = TypeVar("T", bound=EmbeddingFunction)
|
||||
|
||||
|
||||
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
    """Abstract base class for embedding providers.

    This class provides a common interface for dynamically loading and building
    embedding functions from various providers.
    """

    # extra="allow" lets concrete providers accept arbitrary extra settings;
    # populate_by_name permits setting fields by name as well as by alias.
    model_config = SettingsConfigDict(extra="allow", populate_by_name=True)
    # The EmbeddingFunction subclass the factory instantiates for this provider.
    embedding_callable: type[T] = Field(
        ..., description="The embedding function class to use"
    )
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Core type definitions for RAG systems."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
import numpy as np
|
||||
from numpy import floating, integer, number
|
||||
from numpy.typing import NDArray
|
||||
|
||||
T = TypeVar("T")

# Plain-Python embedding representations, as accepted from user code.
PyEmbedding = Sequence[float] | Sequence[int]
PyEmbeddings = list[PyEmbedding]
# Canonical numpy-backed embedding representations used internally.
Embedding = NDArray[np.int32 | np.float32]
Embeddings = list[Embedding]

# Inputs that can be embedded: text documents or raw image arrays.
Documents = list[str]
Images = list[np.ndarray]
Embeddable = Documents | Images

ScalarType = TypeVar("ScalarType", bound=np.generic)
IntegerType = TypeVar("IntegerType", bound=integer)
FloatingType = TypeVar("FloatingType", bound=floating)
NumberType = TypeVar("NumberType", bound=number)

# Constrained dtype TypeVars for 32-bit, 64-bit, and mixed-width numeric code.
DType32 = TypeVar("DType32", np.int32, np.float32)
DType64 = TypeVar("DType64", np.int64, np.float64)
DTypeCommon = TypeVar("DTypeCommon", np.int32, np.int64, np.float32, np.float64)
|
||||
@@ -1 +0,0 @@
|
||||
"""Embedding components for RAG infrastructure."""
|
||||
@@ -1,392 +0,0 @@
|
||||
"""Factory functions for creating embedding providers and functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, TypeVar, overload
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.utilities.import_utils import import_and_validate_definition
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import (
|
||||
HuggingFaceProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||
WatsonXEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderSpec,
|
||||
WatsonXProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
|
||||
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
|
||||
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
T = TypeVar("T", bound=EmbeddingFunction)


# Maps each public provider key (the "provider" value in a spec dict) to the
# fully qualified path of its provider class. Classes are imported lazily in
# build_embedder_from_dict so optional dependencies load only when used.
PROVIDER_PATHS = {
    "azure": "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider",
    "amazon-bedrock": "crewai.rag.embeddings.providers.aws.bedrock.BedrockProvider",
    "cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
    "custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
    "google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
    "google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
    "huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
    "instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
    "jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider",
    "ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider",
    "onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider",
    "openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider",
    "openclip": "crewai.rag.embeddings.providers.openclip.openclip_provider.OpenCLIPProvider",
    "roboflow": "crewai.rag.embeddings.providers.roboflow.roboflow_provider.RoboflowProvider",
    "sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
    "text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
    "voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
    "watson": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider",  # Deprecated alias
    "watsonx": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider",
}
|
||||
|
||||
|
||||
def build_embedder_from_provider(provider: BaseEmbeddingsProvider[T]) -> T:
    """Build an embedding function instance from a provider.

    Args:
        provider: The embedding provider configuration.

    Returns:
        An instance of the specified embedding function type.
    """
    # Everything except the class reference itself becomes constructor kwargs.
    constructor_kwargs = provider.model_dump(exclude={"embedding_callable"})
    embedder_cls = provider.embedding_callable
    return embedder_cls(**constructor_kwargs)
|
||||
|
||||
|
||||
# Typed overloads: map each dict-style provider spec to the concrete
# embedding function class the factory returns for it. The runtime
# implementation follows these stubs; they exist only for type checkers.
@overload
def build_embedder_from_dict(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: BedrockProviderSpec,
) -> AmazonBedrockEmbeddingFunction: ...


@overload
def build_embedder_from_dict(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...


@overload
def build_embedder_from_dict(spec: CustomProviderSpec) -> EmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: GenerativeAiProviderSpec,
) -> GoogleGenerativeAiEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: HuggingFaceProviderSpec,
) -> HuggingFaceEmbeddingFunction: ...


@overload
def build_embedder_from_dict(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...


@overload
def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: VertexAIProviderSpec,
) -> GoogleVertexEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: VoyageAIProviderSpec,
) -> VoyageAIEmbeddingFunction: ...


@overload
def build_embedder_from_dict(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...


# Deprecated overload kept so type checkers surface the deprecation to callers.
@overload
@deprecated(
    'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: SentenceTransformerProviderSpec,
) -> SentenceTransformerEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: InstructorProviderSpec,
) -> InstructorEmbeddingFunction: ...


@overload
def build_embedder_from_dict(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: RoboflowProviderSpec,
) -> RoboflowEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: OpenCLIPProviderSpec,
) -> OpenCLIPEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: Text2VecProviderSpec,
) -> Text2VecEmbeddingFunction: ...


@overload
def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
def build_embedder_from_dict(spec):
    """Build an embedding function instance from a dictionary specification.

    Args:
        spec: A dictionary with 'provider' and 'config' keys.
            Example: {
                "provider": "openai",
                "config": {
                    "api_key": "sk-...",
                    "model_name": "text-embedding-3-small"
                }
            }

    Returns:
        An instance of the appropriate embedding function.

    Raises:
        ValueError: If the provider is missing or not recognized, or if the
            custom provider is used without an embedding_callable.
        ImportError: If the provider class cannot be imported.
    """
    # Use .get() so a missing key raises the documented ValueError below
    # rather than an undocumented KeyError from spec["provider"].
    provider_name = spec.get("provider")
    if not provider_name:
        raise ValueError("Missing 'provider' key in specification")

    if provider_name == "watson":
        warnings.warn(
            'The "watson" provider key is deprecated and will be removed in v1.0.0. '
            'Use "watsonx" instead.',
            DeprecationWarning,
            stacklevel=2,
        )

    if provider_name not in PROVIDER_PATHS:
        raise ValueError(
            f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
        )

    # Lazy import so optional provider dependencies load only when requested.
    provider_path = PROVIDER_PATHS[provider_name]
    try:
        provider_class = import_and_validate_definition(provider_path)
    except (ImportError, AttributeError, ValueError) as e:
        raise ImportError(f"Failed to import provider {provider_name}: {e}") from e

    provider_config = spec.get("config", {})

    # The custom provider has no default embedding class; the caller must supply one.
    if provider_name == "custom" and "embedding_callable" not in provider_config:
        raise ValueError("Custom provider requires 'embedding_callable' in config")

    provider = provider_class(**provider_config)
    return build_embedder_from_provider(provider)
|
||||
|
||||
|
||||
# Typed overloads for build_embedder: a typed provider instance maps straight
# through to its embedding function type; dict specs map to the concrete
# embedding function class for that provider. Stubs are for type checkers only.
@overload
def build_embedder(spec: BaseEmbeddingsProvider[T]) -> T: ...


@overload
def build_embedder(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...


@overload
def build_embedder(spec: BedrockProviderSpec) -> AmazonBedrockEmbeddingFunction: ...


@overload
def build_embedder(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...


@overload
def build_embedder(spec: CustomProviderSpec) -> EmbeddingFunction: ...


@overload
def build_embedder(
    spec: GenerativeAiProviderSpec,
) -> GoogleGenerativeAiEmbeddingFunction: ...


@overload
def build_embedder(spec: HuggingFaceProviderSpec) -> HuggingFaceEmbeddingFunction: ...


@overload
def build_embedder(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...


@overload
def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...


@overload
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...


@overload
def build_embedder(spec: VoyageAIProviderSpec) -> VoyageAIEmbeddingFunction: ...


@overload
def build_embedder(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...


# Deprecated overload kept so type checkers surface the deprecation to callers.
@overload
@deprecated(
    'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
def build_embedder(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...


@overload
def build_embedder(
    spec: SentenceTransformerProviderSpec,
) -> SentenceTransformerEmbeddingFunction: ...


@overload
def build_embedder(spec: InstructorProviderSpec) -> InstructorEmbeddingFunction: ...


@overload
def build_embedder(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...


@overload
def build_embedder(spec: RoboflowProviderSpec) -> RoboflowEmbeddingFunction: ...


@overload
def build_embedder(spec: OpenCLIPProviderSpec) -> OpenCLIPEmbeddingFunction: ...


@overload
def build_embedder(spec: Text2VecProviderSpec) -> Text2VecEmbeddingFunction: ...


@overload
def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
def build_embedder(spec):
    """Build an embedding function from either a provider spec or a provider instance.

    Args:
        spec: Either a provider specification dictionary or a provider instance.

    Returns:
        An embedding function instance. If a typed provider is passed, returns
        the specific embedding function type.

    Examples:
        # From dictionary specification
        embedder = build_embedder({
            "provider": "openai",
            "config": {"api_key": "sk-..."}
        })

        # From provider instance
        provider = OpenAIProvider(api_key="sk-...")
        embedder = build_embedder(provider)
    """
    # Anything that is not a provider instance is treated as a dict-style spec.
    if not isinstance(spec, BaseEmbeddingsProvider):
        return build_embedder_from_dict(spec)
    return build_embedder_from_provider(spec)


# Backward compatibility alias
get_embedding_function = build_embedder
|
||||
@@ -1 +0,0 @@
|
||||
"""Embedding provider implementations."""
|
||||
@@ -1,13 +0,0 @@
|
||||
"""AWS embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
|
||||
from crewai.rag.embeddings.providers.aws.types import (
|
||||
BedrockProviderConfig,
|
||||
BedrockProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BedrockProvider",
|
||||
"BedrockProviderConfig",
|
||||
"BedrockProviderSpec",
|
||||
]
|
||||
@@ -1,53 +0,0 @@
|
||||
"""Amazon Bedrock embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
def create_aws_session() -> Any:
    """Create an AWS session for Bedrock.

    Returns:
        boto3.Session: AWS session object.

    Raises:
        ImportError: If boto3 is not installed.
        ValueError: If AWS session creation fails for any other reason.
    """
    try:
        import boto3  # type: ignore[import]

        return boto3.Session()
    except ImportError as import_error:
        raise ImportError(
            "boto3 is required for amazon-bedrock embeddings. Install it with: uv add boto3"
        ) from import_error
    except Exception as session_error:
        raise ValueError(
            f"Failed to create AWS session for amazon-bedrock. Ensure AWS credentials are configured. Error: {session_error}"
        ) from session_error
|
||||
|
||||
|
||||
class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
    """Amazon Bedrock embeddings provider."""

    # Concrete embedding function class instantiated by the embedder factory.
    embedding_callable: type[AmazonBedrockEmbeddingFunction] = Field(
        default=AmazonBedrockEmbeddingFunction,
        description="Amazon Bedrock embedding function class",
    )
    # May also be populated from the EMBEDDINGS_BEDROCK_MODEL_NAME setting
    # via the pydantic-settings validation alias.
    model_name: str = Field(
        default="amazon.titan-embed-text-v1",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
    )
    # A boto3 session is created eagerly at model construction time when
    # no session is supplied.
    session: Any = Field(
        default_factory=create_aws_session, description="AWS session object"
    )
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for AWS embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class BedrockProviderConfig(TypedDict, total=False):
    """Configuration for Bedrock provider."""

    # Bedrock model identifier; the provider defaults to "amazon.titan-embed-text-v1".
    model_name: Annotated[str, "amazon.titan-embed-text-v1"]
    # Pre-built boto3 session; the provider creates one automatically when omitted.
    session: Any


class BedrockProviderSpec(TypedDict, total=False):
    """Bedrock provider specification."""

    # Discriminator key used by the embedder factory.
    provider: Required[Literal["amazon-bedrock"]]
    config: BedrockProviderConfig
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Cohere embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
|
||||
from crewai.rag.embeddings.providers.cohere.types import (
|
||||
CohereProviderConfig,
|
||||
CohereProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CohereProvider",
|
||||
"CohereProviderConfig",
|
||||
"CohereProviderSpec",
|
||||
]
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Cohere embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
    """Cohere embeddings provider."""

    # Concrete embedding function class instantiated by the embedder factory.
    embedding_callable: type[CohereEmbeddingFunction] = Field(
        default=CohereEmbeddingFunction, description="Cohere embedding function class"
    )
    # Required field; may also be populated from the EMBEDDINGS_COHERE_API_KEY
    # setting via the pydantic-settings validation alias.
    api_key: str = Field(
        description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
    )
    model_name: str = Field(
        default="large",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
    )
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for Cohere embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class CohereProviderConfig(TypedDict, total=False):
    """Configuration for Cohere provider."""

    # Cohere API key; required by CohereProvider at construction time.
    api_key: str
    # Embedding model name; the provider defaults to "large".
    model_name: Annotated[str, "large"]


class CohereProviderSpec(TypedDict, total=False):
    """Cohere provider specification."""

    # Discriminator key used by the embedder factory.
    provider: Required[Literal["cohere"]]
    config: CohereProviderConfig
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Custom embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.custom.custom_provider import CustomProvider
|
||||
from crewai.rag.embeddings.providers.custom.types import (
|
||||
CustomProviderConfig,
|
||||
CustomProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CustomProvider",
|
||||
"CustomProviderConfig",
|
||||
"CustomProviderSpec",
|
||||
]
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Custom embeddings provider for user-defined embedding functions."""
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.custom.embedding_callable import (
|
||||
CustomEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class CustomProvider(BaseEmbeddingsProvider[CustomEmbeddingFunction]):
    """Custom embeddings provider for user-defined embedding functions."""

    # No default: callers must supply their own embedding function class.
    embedding_callable: type[CustomEmbeddingFunction] = Field(
        ..., description="Custom embedding function class"
    )

    # extra="allow" lets arbitrary constructor kwargs flow through to the
    # user-supplied embedding function.
    model_config = SettingsConfigDict(extra="allow")
|
||||
@@ -1,22 +0,0 @@
|
||||
"""Custom embedding function base implementation."""
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
|
||||
|
||||
class CustomEmbeddingFunction(EmbeddingFunction[Documents]):
    """Base class for custom embedding functions.

    This provides a concrete implementation that can be subclassed for custom embeddings.
    """

    def __call__(self, input: Documents) -> Embeddings:
        """Convert input documents to embeddings.

        Args:
            input: List of documents to embed.

        Returns:
            List of numpy arrays representing the embeddings.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement __call__ method")
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for custom embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from chromadb.api.types import EmbeddingFunction
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class CustomProviderConfig(TypedDict, total=False):
    """Configuration for Custom provider."""

    # The user-supplied embedding function class. Required in practice: the
    # embedder factory raises ValueError when it is missing from the config.
    embedding_callable: type[EmbeddingFunction]


class CustomProviderSpec(TypedDict, total=False):
    """Custom provider specification."""

    # Discriminator key used by the embedder factory.
    provider: Required[Literal["custom"]]
    config: CustomProviderConfig
|
||||
@@ -1,23 +0,0 @@
|
||||
"""Google embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.google.generative_ai import (
|
||||
GenerativeAiProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderConfig,
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderConfig,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.vertex import (
|
||||
VertexAIProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GenerativeAiProvider",
|
||||
"GenerativeAiProviderConfig",
|
||||
"GenerativeAiProviderSpec",
|
||||
"VertexAIProvider",
|
||||
"VertexAIProviderConfig",
|
||||
"VertexAIProviderSpec",
|
||||
]
|
||||
@@ -1,30 +0,0 @@
|
||||
"""Google Generative AI embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFunction]):
    """Google Generative AI embeddings provider."""

    # Concrete embedding function class instantiated by the embedder factory.
    embedding_callable: type[GoogleGenerativeAiEmbeddingFunction] = Field(
        default=GoogleGenerativeAiEmbeddingFunction,
        description="Google Generative AI embedding function class",
    )
    model_name: str = Field(
        default="models/embedding-001",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
    )
    # Required field; may also be populated from the EMBEDDINGS_GOOGLE_API_KEY
    # setting via the pydantic-settings validation alias.
    api_key: str = Field(
        description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
    )
    task_type: str = Field(
        default="RETRIEVAL_DOCUMENT",
        description="Task type for embeddings",
        validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
    )
||||
@@ -1,36 +0,0 @@
|
||||
"""Type definitions for Google embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class GenerativeAiProviderConfig(TypedDict, total=False):
    """Configuration for Google Generative AI provider.

    All keys are optional (``total=False``). The second ``Annotated``
    argument appears to record the default used by ``GenerativeAiProvider``
    — documentation metadata only, not enforced here.
    """

    api_key: str
    model_name: Annotated[str, "models/embedding-001"]
    task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
|
||||
|
||||
class GenerativeAiProviderSpec(TypedDict, total=False):
    """Google Generative AI provider specification.

    Declared ``total=False`` with a ``Required`` discriminator so that
    ``config`` is optional — consistent with every sibling spec
    (e.g. ``VertexAIProviderSpec``), which this class previously diverged
    from by requiring both keys.
    """

    # Discriminator key: always required.
    provider: Required[Literal["google-generativeai"]]
    # Optional provider configuration; defaults apply when omitted.
    config: GenerativeAiProviderConfig
|
||||
|
||||
|
||||
class VertexAIProviderConfig(TypedDict, total=False):
    """Configuration for Vertex AI provider.

    All keys optional (``total=False``); ``Annotated`` metadata appears to
    record the defaults used by ``VertexAIProvider``.
    """

    api_key: str
    model_name: Annotated[str, "textembedding-gecko"]
    project_id: Annotated[str, "cloud-large-language-models"]
    region: Annotated[str, "us-central1"]
|
||||
|
||||
|
||||
class VertexAIProviderSpec(TypedDict, total=False):
    """Vertex AI provider specification.

    ``provider`` is the mandatory discriminator; ``config`` may be omitted.
    """

    provider: Required[Literal["google-vertex"]]
    config: VertexAIProviderConfig
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Google Vertex AI embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
    """Google Vertex AI embeddings provider.

    Declarative pydantic configuration for ChromaDB's
    ``GoogleVertexEmbeddingFunction``. ``validation_alias`` names suggest
    env-var population — confirm in ``BaseEmbeddingsProvider``.
    """

    # Embedding-function class the base provider will instantiate.
    embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
        default=GoogleVertexEmbeddingFunction,
        description="Vertex AI embedding function class",
    )
    model_name: str = Field(
        default="textembedding-gecko",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
    )
    # No default: must be supplied directly or via alias.
    api_key: str = Field(
        description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
    )
    project_id: str = Field(
        default="cloud-large-language-models",
        description="GCP project ID",
        validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
    )
    region: str = Field(
        default="us-central1",
        description="GCP region",
        validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
    )
|
||||
@@ -1,15 +0,0 @@
|
||||
"""HuggingFace embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
|
||||
HuggingFaceProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import (
|
||||
HuggingFaceProviderConfig,
|
||||
HuggingFaceProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"HuggingFaceProvider",
|
||||
"HuggingFaceProviderConfig",
|
||||
"HuggingFaceProviderSpec",
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
    """HuggingFace embeddings provider.

    Targets a self-hosted HuggingFace text-embeddings server via ChromaDB's
    ``HuggingFaceEmbeddingServer`` — only the server URL is configurable.
    """

    # Embedding-function class the base provider will instantiate.
    embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
        default=HuggingFaceEmbeddingServer,
        description="HuggingFace embedding function class",
    )
    # No default: must be supplied directly or via the alias.
    url: str = Field(
        description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
    )
|
||||
@@ -1,18 +0,0 @@
|
||||
"""Type definitions for HuggingFace embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class HuggingFaceProviderConfig(TypedDict, total=False):
    """Configuration for HuggingFace provider."""

    # Embedding-server endpoint; the only tunable for this provider.
    url: str
|
||||
|
||||
|
||||
class HuggingFaceProviderSpec(TypedDict, total=False):
    """HuggingFace provider specification.

    ``provider`` is the mandatory discriminator; ``config`` may be omitted.
    """

    provider: Required[Literal["huggingface"]]
    config: HuggingFaceProviderConfig
|
||||
@@ -1,17 +0,0 @@
|
||||
"""IBM embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderSpec,
|
||||
WatsonXProviderConfig,
|
||||
WatsonXProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.watsonx import (
|
||||
WatsonXProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"WatsonProviderSpec",
|
||||
"WatsonXProvider",
|
||||
"WatsonXProviderConfig",
|
||||
"WatsonXProviderSpec",
|
||||
]
|
||||
@@ -1,159 +0,0 @@
|
||||
"""IBM WatsonX embedding function implementation."""
|
||||
|
||||
import logging
from typing import cast

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack

from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
|
||||
|
||||
|
||||
class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function for IBM WatsonX models.

    Configuration is captured at construction time; the heavyweight
    ``ibm_watsonx_ai`` SDK is imported lazily on each call so the class can
    be declared without the optional dependency installed.
    """

    def __init__(self, **kwargs: Unpack[WatsonXProviderConfig]) -> None:
        """Initialize WatsonX embedding function.

        Args:
            **kwargs: Configuration parameters for WatsonX Embeddings and
                Credentials (see ``WatsonXProviderConfig``).
        """
        super().__init__(**kwargs)
        self._config = kwargs

    @staticmethod
    def name() -> str:
        """Return the name of the embedding function for ChromaDB compatibility."""
        return "watsonx"

    def _build_embeddings_config(self) -> dict:
        """Collect keyword arguments for ``Embeddings`` from the stored config.

        Returns:
            Keyword arguments for ``ibm_watsonx_ai.foundation_models.Embeddings``.

        Raises:
            KeyError: If the mandatory ``model_id`` key is missing.
        """
        config: dict = {"model_id": self._config["model_id"]}
        # These keys are forwarded only when present AND not None.
        for key in (
            "params",
            "project_id",
            "space_id",
            "api_client",
            "verify",
            "max_retries",
            "delay_time",
            "retry_status_codes",
        ):
            if self._config.get(key) is not None:
                config[key] = self._config[key]
        # These are forwarded whenever present, even if explicitly None/False,
        # so callers can override the SDK defaults with falsy values.
        for key in ("persistent_connection", "batch_size", "concurrency_limit"):
            if key in self._config:
                config[key] = self._config[key]
        return config

    def _build_credentials_kwargs(self, embeddings_config: dict) -> dict:
        """Collect keyword arguments for ``Credentials`` from the stored config.

        Args:
            embeddings_config: The already-built ``Embeddings`` kwargs; used to
                avoid passing ``verify`` twice.

        Returns:
            Keyword arguments for ``ibm_watsonx_ai.Credentials`` (possibly empty).
        """
        cred_config: dict = {}
        for key in (
            "url",
            "api_key",
            "name",
            "iam_serviceid_crn",
            "trusted_profile_id",
            "token",
            "projects_token",
            "username",
            "password",
            "instance_id",
            "version",
            "bedrock_url",
            "platform_url",
            "proxies",
        ):
            if self._config.get(key) is not None:
                cred_config[key] = self._config[key]
        # Only fall back to credential-level verify when it was not already
        # given to Embeddings directly.
        if "verify" not in embeddings_config and self._config.get("verify") is not None:
            cred_config["verify"] = self._config["verify"]
        return cred_config

    def __call__(self, input: Documents) -> Embeddings:
        """Generate embeddings for input documents.

        Args:
            input: List of documents to embed (a bare string is wrapped into a
                single-element list).

        Returns:
            List of embedding vectors.

        Raises:
            ImportError: If ``ibm-watsonx-ai`` is not installed.
        """
        try:
            import ibm_watsonx_ai.foundation_models as watson_models  # type: ignore[import-not-found, import-untyped]
            from ibm_watsonx_ai import (
                Credentials,  # type: ignore[import-not-found, import-untyped]
            )
            from ibm_watsonx_ai.metanames import (  # type: ignore[import-not-found, import-untyped]
                EmbedTextParamsMetaNames as EmbedParams,
            )
        except ImportError as e:
            raise ImportError(
                "ibm-watsonx-ai is required for watsonx embeddings. "
                "Install it with: uv add ibm-watsonx-ai"
            ) from e

        if isinstance(input, str):
            input = [input]

        embeddings_config = self._build_embeddings_config()

        # Prefer an explicitly supplied Credentials object; otherwise build
        # one from the individual credential keys, if any were given.
        if self._config.get("credentials") is not None:
            embeddings_config["credentials"] = self._config["credentials"]
        else:
            cred_config = self._build_credentials_kwargs(embeddings_config)
            if cred_config:
                embeddings_config["credentials"] = Credentials(**cred_config)

        if "params" not in embeddings_config:
            # Defaults: truncate long inputs and echo input text in responses.
            embeddings_config["params"] = {
                EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
                EmbedParams.RETURN_OPTIONS: {"input_text": True},
            }

        embedding = watson_models.Embeddings(**embeddings_config)

        try:
            embeddings = embedding.embed_documents(input)
            return cast(Embeddings, embeddings)
        except Exception:
            # Log with traceback instead of print(); the exception still
            # propagates to the caller.
            logging.getLogger(__name__).exception("Error during WatsonX embedding")
            raise
|
||||
@@ -1,58 +0,0 @@
|
||||
"""Type definitions for IBM WatsonX embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict, deprecated
|
||||
|
||||
|
||||
class WatsonXProviderConfig(TypedDict, total=False):
    """Configuration for WatsonX provider.

    All keys optional (``total=False``). Keys fall into two groups consumed
    by ``WatsonXEmbeddingFunction``: ``Embeddings`` client options and
    ``Credentials`` options. ``Annotated`` metadata appears to record the
    defaults used by ``WatsonXProvider``.
    """

    # --- Embeddings client options ---
    model_id: str
    url: str
    params: dict[str, str | dict[str, str]]
    # Pre-built ibm_watsonx_ai Credentials object; takes precedence over the
    # individual credential keys below.
    credentials: Any
    project_id: str
    space_id: str
    api_client: Any
    verify: bool | str
    persistent_connection: Annotated[bool, True]
    batch_size: Annotated[int, 100]
    concurrency_limit: Annotated[int, 10]
    max_retries: int
    delay_time: float
    retry_status_codes: list[int]
    # --- Credentials options ---
    api_key: str
    name: str
    iam_serviceid_crn: str
    trusted_profile_id: str
    token: str
    projects_token: str
    username: str
    password: str
    instance_id: str
    version: str
    bedrock_url: str
    platform_url: str
    proxies: dict
|
||||
|
||||
|
||||
class WatsonXProviderSpec(TypedDict, total=False):
    """WatsonX provider specification.

    ``provider`` is the mandatory discriminator; ``config`` may be omitted.
    """

    provider: Required[Literal["watsonx"]]
    config: WatsonXProviderConfig
|
||||
|
||||
|
||||
@deprecated(
    'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
class WatsonProviderSpec(TypedDict, total=False):
    """Watson provider specification (deprecated).

    Kept for backward compatibility with the legacy "watson" provider name;
    shares its config shape with ``WatsonXProviderSpec``.

    Notes:
        - This is deprecated. Use WatsonXProviderSpec with provider="watsonx" instead.
    """

    provider: Required[Literal["watson"]]
    config: WatsonXProviderConfig
|
||||
@@ -1,142 +0,0 @@
|
||||
"""IBM WatsonX embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||
WatsonXEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
    """IBM WatsonX embeddings provider.

    Note: Requires custom implementation as WatsonX uses a different interface.

    At least one of ``space_id`` or ``project_id`` must be provided
    (enforced by ``validate_space_or_project``). ``validation_alias`` names
    suggest env-var population — confirm in ``BaseEmbeddingsProvider``.
    """

    # Embedding-function class the base provider will instantiate.
    embedding_callable: type[WatsonXEmbeddingFunction] = Field(
        default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
    )
    # --- Embeddings client options ---
    model_id: str = Field(
        description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
    )
    params: dict[str, str | dict[str, str]] | None = Field(
        default=None, description="Additional parameters"
    )
    # Pre-built Credentials object; overrides the individual credential
    # fields below when set.
    credentials: Any | None = Field(default=None, description="WatsonX credentials")
    project_id: str | None = Field(
        default=None,
        description="WatsonX project ID",
        validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
    )
    space_id: str | None = Field(
        default=None,
        description="WatsonX space ID",
        validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
    )
    api_client: Any | None = Field(default=None, description="WatsonX API client")
    verify: bool | str | None = Field(
        default=None,
        description="SSL verification",
        validation_alias="EMBEDDINGS_WATSONX_VERIFY",
    )
    persistent_connection: bool = Field(
        default=True,
        description="Use persistent connection",
        validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
    )
    batch_size: int = Field(
        default=100,
        description="Batch size for processing",
        validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
    )
    concurrency_limit: int = Field(
        default=10,
        description="Concurrency limit",
        validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
    )
    max_retries: int | None = Field(
        default=None,
        description="Maximum retries",
        validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
    )
    delay_time: float | None = Field(
        default=None,
        description="Delay time between retries",
        validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
    )
    retry_status_codes: list[int] | None = Field(
        default=None, description="HTTP status codes to retry on"
    )
    # --- Credentials options ---
    url: str = Field(
        description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
    )
    api_key: str = Field(
        description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
    )
    name: str | None = Field(
        default=None,
        description="Service name",
        validation_alias="EMBEDDINGS_WATSONX_NAME",
    )
    iam_serviceid_crn: str | None = Field(
        default=None,
        description="IAM service ID CRN",
        validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
    )
    trusted_profile_id: str | None = Field(
        default=None,
        description="Trusted profile ID",
        validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
    )
    token: str | None = Field(
        default=None,
        description="Bearer token",
        validation_alias="EMBEDDINGS_WATSONX_TOKEN",
    )
    projects_token: str | None = Field(
        default=None,
        description="Projects token",
        validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
    )
    username: str | None = Field(
        default=None,
        description="Username",
        validation_alias="EMBEDDINGS_WATSONX_USERNAME",
    )
    password: str | None = Field(
        default=None,
        description="Password",
        validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
    )
    instance_id: str | None = Field(
        default=None,
        description="Service instance ID",
        validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
    )
    version: str | None = Field(
        default=None,
        description="API version",
        validation_alias="EMBEDDINGS_WATSONX_VERSION",
    )
    bedrock_url: str | None = Field(
        default=None,
        description="Bedrock URL",
        validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
    )
    platform_url: str | None = Field(
        default=None,
        description="Platform URL",
        validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
    )
    proxies: dict | None = Field(default=None, description="Proxy configuration")

    @model_validator(mode="after")
    def validate_space_or_project(self) -> Self:
        """Validate that either space_id or project_id is provided.

        Raises:
            ValueError: If both ``space_id`` and ``project_id`` are unset.
        """
        if not self.space_id and not self.project_id:
            raise ValueError("One of 'space_id' or 'project_id' must be provided")
        return self
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Instructor embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.instructor.instructor_provider import (
|
||||
InstructorProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.types import (
|
||||
InstructorProviderConfig,
|
||||
InstructorProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"InstructorProvider",
|
||||
"InstructorProviderConfig",
|
||||
"InstructorProviderSpec",
|
||||
]
|
||||
@@ -1,32 +0,0 @@
|
||||
"""Instructor embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
    """Instructor embeddings provider.

    Declarative pydantic configuration for ChromaDB's
    ``InstructorEmbeddingFunction`` (local instruction-tuned embeddings).
    """

    # Embedding-function class the base provider will instantiate.
    embedding_callable: type[InstructorEmbeddingFunction] = Field(
        default=InstructorEmbeddingFunction,
        description="Instructor embedding function class",
    )
    model_name: str = Field(
        default="hkunlp/instructor-base",
        description="Model name to use",
        validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
    )
    device: str = Field(
        default="cpu",
        description="Device to run model on (cpu or cuda)",
        validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
    )
    # Optional natural-language instruction prepended by the model.
    instruction: str | None = Field(
        default=None,
        description="Instruction for embeddings",
        validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
    )
|
||||
@@ -1,20 +0,0 @@
|
||||
"""Type definitions for Instructor embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class InstructorProviderConfig(TypedDict, total=False):
    """Configuration for Instructor provider.

    All keys optional (``total=False``); ``Annotated`` metadata appears to
    record the defaults used by ``InstructorProvider``.
    """

    model_name: Annotated[str, "hkunlp/instructor-base"]
    device: Annotated[str, "cpu"]
    instruction: str
|
||||
|
||||
|
||||
class InstructorProviderSpec(TypedDict, total=False):
    """Instructor provider specification.

    ``provider`` is the mandatory discriminator; ``config`` may be omitted.
    """

    provider: Required[Literal["instructor"]]
    config: InstructorProviderConfig
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Jina embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
|
||||
from crewai.rag.embeddings.providers.jina.types import (
|
||||
JinaProviderConfig,
|
||||
JinaProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JinaProvider",
|
||||
"JinaProviderConfig",
|
||||
"JinaProviderSpec",
|
||||
]
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Jina embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
    """Jina embeddings provider.

    Declarative pydantic configuration for ChromaDB's
    ``JinaEmbeddingFunction`` (Jina AI hosted embeddings API).
    """

    # Embedding-function class the base provider will instantiate.
    embedding_callable: type[JinaEmbeddingFunction] = Field(
        default=JinaEmbeddingFunction, description="Jina embedding function class"
    )
    # No default: must be supplied directly or via the alias.
    api_key: str = Field(
        description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
    )
    model_name: str = Field(
        default="jina-embeddings-v2-base-en",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
    )
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for Jina embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class JinaProviderConfig(TypedDict, total=False):
    """Configuration for Jina provider.

    All keys optional (``total=False``); ``Annotated`` metadata appears to
    record the default used by ``JinaProvider``.
    """

    api_key: str
    model_name: Annotated[str, "jina-embeddings-v2-base-en"]
|
||||
|
||||
|
||||
class JinaProviderSpec(TypedDict, total=False):
    """Jina provider specification.

    ``provider`` is the mandatory discriminator; ``config`` may be omitted.
    """

    provider: Required[Literal["jina"]]
    config: JinaProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Microsoft embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.microsoft.azure import (
|
||||
AzureProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.microsoft.types import (
|
||||
AzureProviderConfig,
|
||||
AzureProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureProvider",
|
||||
"AzureProviderConfig",
|
||||
"AzureProviderSpec",
|
||||
]
|
||||
@@ -1,60 +0,0 @@
|
||||
"""Azure OpenAI embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
    """Azure OpenAI embeddings provider.

    Reuses ChromaDB's ``OpenAIEmbeddingFunction`` with ``api_type`` defaulting
    to "azure". Note the ``EMBEDDINGS_OPENAI_*`` aliases are shared with
    ``OpenAIProvider`` — presumably intentional, since both drive the same
    underlying OpenAI client; confirm if distinct Azure env vars are wanted.
    """

    # Embedding-function class the base provider will instantiate.
    embedding_callable: type[OpenAIEmbeddingFunction] = Field(
        default=OpenAIEmbeddingFunction,
        description="Azure OpenAI embedding function class",
    )
    # No default: must be supplied directly or via the alias.
    api_key: str = Field(
        description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
    )
    api_base: str | None = Field(
        default=None,
        description="Azure endpoint URL",
        validation_alias="EMBEDDINGS_OPENAI_API_BASE",
    )
    api_type: str = Field(
        default="azure",
        description="API type for Azure",
        validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
    )
    api_version: str | None = Field(
        default=None,
        description="Azure API version",
        validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
    )
    model_name: str = Field(
        default="text-embedding-ada-002",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
    )
    default_headers: dict[str, Any] | None = Field(
        default=None, description="Default headers for API requests"
    )
    dimensions: int | None = Field(
        default=None,
        description="Embedding dimensions",
        validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
    )
    deployment_id: str | None = Field(
        default=None,
        description="Azure deployment ID",
        validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
    )
    organization_id: str | None = Field(
        default=None,
        description="Organization ID",
        validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
    )
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Type definitions for Microsoft Azure embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class AzureProviderConfig(TypedDict, total=False):
    """Configuration for Azure provider.

    All keys optional (``total=False``); ``Annotated`` metadata appears to
    record the defaults used by ``AzureProvider``.
    """

    api_key: str
    api_base: str
    api_type: Annotated[str, "azure"]
    api_version: str
    model_name: Annotated[str, "text-embedding-ada-002"]
    default_headers: dict[str, Any]
    dimensions: int
    deployment_id: str
    organization_id: str
|
||||
|
||||
|
||||
class AzureProviderSpec(TypedDict, total=False):
    """Azure provider specification.

    ``provider`` is the mandatory discriminator; ``config`` may be omitted.
    """

    provider: Required[Literal["azure"]]
    config: AzureProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Ollama embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.ollama.ollama_provider import (
|
||||
OllamaProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ollama.types import (
|
||||
OllamaProviderConfig,
|
||||
OllamaProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OllamaProvider",
|
||||
"OllamaProviderConfig",
|
||||
"OllamaProviderSpec",
|
||||
]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Ollama embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
    """Ollama embeddings provider.

    Declarative pydantic configuration for ChromaDB's
    ``OllamaEmbeddingFunction``, defaulting to a local Ollama server.
    """

    # Embedding-function class the base provider will instantiate.
    embedding_callable: type[OllamaEmbeddingFunction] = Field(
        default=OllamaEmbeddingFunction, description="Ollama embedding function class"
    )
    url: str = Field(
        default="http://localhost:11434/api/embeddings",
        description="Ollama API endpoint URL",
        validation_alias="EMBEDDINGS_OLLAMA_URL",
    )
    # No default: the model must be chosen explicitly.
    model_name: str = Field(
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
    )
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for Ollama embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OllamaProviderConfig(TypedDict, total=False):
    """Configuration for Ollama provider.

    All keys optional (``total=False``); ``Annotated`` metadata appears to
    record the default used by ``OllamaProvider``.
    """

    url: Annotated[str, "http://localhost:11434/api/embeddings"]
    model_name: str
|
||||
|
||||
|
||||
class OllamaProviderSpec(TypedDict, total=False):
    """Ollama provider specification.

    ``provider`` is the mandatory discriminator; ``config`` may be omitted.
    """

    provider: Required[Literal["ollama"]]
    config: OllamaProviderConfig
|
||||
@@ -1,13 +0,0 @@
|
||||
"""ONNX embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.onnx.onnx_provider import ONNXProvider
|
||||
from crewai.rag.embeddings.providers.onnx.types import (
|
||||
ONNXProviderConfig,
|
||||
ONNXProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ONNXProvider",
|
||||
"ONNXProviderConfig",
|
||||
"ONNXProviderSpec",
|
||||
]
|
||||
@@ -1,19 +0,0 @@
|
||||
"""ONNX embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
    """ONNX embeddings provider.

    Wraps ChromaDB's bundled ``ONNXMiniLM_L6_V2`` local embedding model; the
    only tunable is the list of preferred ONNX execution providers.
    """

    # Embedding-function class the base provider will instantiate.
    embedding_callable: type[ONNXMiniLM_L6_V2] = Field(
        default=ONNXMiniLM_L6_V2, description="ONNX MiniLM embedding function class"
    )
    preferred_providers: list[str] | None = Field(
        default=None,
        description="Preferred ONNX execution providers",
        validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
    )
|
||||
@@ -1,18 +0,0 @@
|
||||
"""Type definitions for ONNX embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class ONNXProviderConfig(TypedDict, total=False):
    """Configuration for ONNX provider."""

    # ONNX Runtime execution providers, in preference order.
    preferred_providers: list[str]
|
||||
|
||||
|
||||
class ONNXProviderSpec(TypedDict, total=False):
    """ONNX provider specification.

    ``provider`` is the mandatory discriminator; ``config`` may be omitted.
    """

    provider: Required[Literal["onnx"]]
    config: ONNXProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""OpenAI embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.openai.openai_provider import (
|
||||
OpenAIProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.openai.types import (
|
||||
OpenAIProviderConfig,
|
||||
OpenAIProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OpenAIProvider",
|
||||
"OpenAIProviderConfig",
|
||||
"OpenAIProviderSpec",
|
||||
]
|
||||
@@ -1,62 +0,0 @@
|
||||
"""OpenAI embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
    """OpenAI embeddings provider.

    Declarative pydantic configuration for ChromaDB's
    ``OpenAIEmbeddingFunction``. Azure-specific fields (``api_type``,
    ``api_version``, ``deployment_id``) are also exposed here because the
    same embedding function serves both backends.
    """

    # Embedding-function class the base provider will instantiate.
    embedding_callable: type[OpenAIEmbeddingFunction] = Field(
        default=OpenAIEmbeddingFunction,
        description="OpenAI embedding function class",
    )
    # Optional here (unlike AzureProvider) — presumably the underlying client
    # can fall back to its own environment handling; confirm.
    api_key: str | None = Field(
        default=None,
        description="OpenAI API key",
        validation_alias="EMBEDDINGS_OPENAI_API_KEY",
    )
    model_name: str = Field(
        default="text-embedding-ada-002",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
    )
    api_base: str | None = Field(
        default=None,
        description="Base URL for API requests",
        validation_alias="EMBEDDINGS_OPENAI_API_BASE",
    )
    api_type: str | None = Field(
        default=None,
        description="API type (e.g., 'azure')",
        validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
    )
    api_version: str | None = Field(
        default=None,
        description="API version",
        validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
    )
    default_headers: dict[str, Any] | None = Field(
        default=None, description="Default headers for API requests"
    )
    dimensions: int | None = Field(
        default=None,
        description="Embedding dimensions",
        validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
    )
    deployment_id: str | None = Field(
        default=None,
        description="Azure deployment ID",
        validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
    )
    organization_id: str | None = Field(
        default=None,
        description="OpenAI organization ID",
        validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
    )
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Type definitions for OpenAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OpenAIProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for OpenAI provider."""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "text-embedding-ada-002"]
|
||||
api_base: str
|
||||
api_type: str
|
||||
api_version: str
|
||||
default_headers: dict[str, Any]
|
||||
dimensions: int
|
||||
deployment_id: str
|
||||
organization_id: str
|
||||
|
||||
|
||||
class OpenAIProviderSpec(TypedDict, total=False):
|
||||
"""OpenAI provider specification."""
|
||||
|
||||
provider: Required[Literal["openai"]]
|
||||
config: OpenAIProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""OpenCLIP embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.openclip.openclip_provider import (
|
||||
OpenCLIPProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.openclip.types import (
|
||||
OpenCLIPProviderConfig,
|
||||
OpenCLIPProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OpenCLIPProvider",
|
||||
"OpenCLIPProviderConfig",
|
||||
"OpenCLIPProviderSpec",
|
||||
]
|
||||
@@ -1,32 +0,0 @@
|
||||
"""OpenCLIP embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
|
||||
"""OpenCLIP embeddings provider."""
|
||||
|
||||
embedding_callable: type[OpenCLIPEmbeddingFunction] = Field(
|
||||
default=OpenCLIPEmbeddingFunction,
|
||||
description="OpenCLIP embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="ViT-B-32",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||
)
|
||||
checkpoint: str = Field(
|
||||
default="laion2b_s34b_b79k",
|
||||
description="Model checkpoint",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
|
||||
)
|
||||
device: str | None = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
|
||||
)
|
||||
@@ -1,20 +0,0 @@
|
||||
"""Type definitions for OpenCLIP embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OpenCLIPProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for OpenCLIP provider."""
|
||||
|
||||
model_name: Annotated[str, "ViT-B-32"]
|
||||
checkpoint: Annotated[str, "laion2b_s34b_b79k"]
|
||||
device: Annotated[str, "cpu"]
|
||||
|
||||
|
||||
class OpenCLIPProviderSpec(TypedDict):
|
||||
"""OpenCLIP provider specification."""
|
||||
|
||||
provider: Required[Literal["openclip"]]
|
||||
config: OpenCLIPProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Roboflow embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.roboflow.roboflow_provider import (
|
||||
RoboflowProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.roboflow.types import (
|
||||
RoboflowProviderConfig,
|
||||
RoboflowProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RoboflowProvider",
|
||||
"RoboflowProviderConfig",
|
||||
"RoboflowProviderSpec",
|
||||
]
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Roboflow embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
|
||||
"""Roboflow embeddings provider."""
|
||||
|
||||
embedding_callable: type[RoboflowEmbeddingFunction] = Field(
|
||||
default=RoboflowEmbeddingFunction,
|
||||
description="Roboflow embedding function class",
|
||||
)
|
||||
api_key: str = Field(
|
||||
default="",
|
||||
description="Roboflow API key",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
|
||||
)
|
||||
api_url: str = Field(
|
||||
default="https://infer.roboflow.com",
|
||||
description="Roboflow API URL",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
|
||||
)
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for Roboflow embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class RoboflowProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Roboflow provider."""
|
||||
|
||||
api_key: Annotated[str, ""]
|
||||
api_url: Annotated[str, "https://infer.roboflow.com"]
|
||||
|
||||
|
||||
class RoboflowProviderSpec(TypedDict):
|
||||
"""Roboflow provider specification."""
|
||||
|
||||
provider: Required[Literal["roboflow"]]
|
||||
config: RoboflowProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""SentenceTransformer embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
|
||||
SentenceTransformerProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderConfig,
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SentenceTransformerProvider",
|
||||
"SentenceTransformerProviderConfig",
|
||||
"SentenceTransformerProviderSpec",
|
||||
]
|
||||
@@ -1,34 +0,0 @@
|
||||
"""SentenceTransformer embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class SentenceTransformerProvider(
|
||||
BaseEmbeddingsProvider[SentenceTransformerEmbeddingFunction]
|
||||
):
|
||||
"""SentenceTransformer embeddings provider."""
|
||||
|
||||
embedding_callable: type[SentenceTransformerEmbeddingFunction] = Field(
|
||||
default=SentenceTransformerEmbeddingFunction,
|
||||
description="SentenceTransformer embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="all-MiniLM-L6-v2",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
|
||||
)
|
||||
normalize_embeddings: bool = Field(
|
||||
default=False,
|
||||
description="Whether to normalize embeddings",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
)
|
||||
@@ -1,20 +0,0 @@
|
||||
"""Type definitions for SentenceTransformer embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class SentenceTransformerProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for SentenceTransformer provider."""
|
||||
|
||||
model_name: Annotated[str, "all-MiniLM-L6-v2"]
|
||||
device: Annotated[str, "cpu"]
|
||||
normalize_embeddings: Annotated[bool, False]
|
||||
|
||||
|
||||
class SentenceTransformerProviderSpec(TypedDict):
|
||||
"""SentenceTransformer provider specification."""
|
||||
|
||||
provider: Required[Literal["sentence-transformer"]]
|
||||
config: SentenceTransformerProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Text2Vec embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import (
|
||||
Text2VecProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import (
|
||||
Text2VecProviderConfig,
|
||||
Text2VecProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Text2VecProvider",
|
||||
"Text2VecProviderConfig",
|
||||
"Text2VecProviderSpec",
|
||||
]
|
||||
@@ -1,22 +0,0 @@
|
||||
"""Text2Vec embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
|
||||
"""Text2Vec embeddings provider."""
|
||||
|
||||
embedding_callable: type[Text2VecEmbeddingFunction] = Field(
|
||||
default=Text2VecEmbeddingFunction,
|
||||
description="Text2Vec embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="shibing624/text2vec-base-chinese",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||
)
|
||||
@@ -1,18 +0,0 @@
|
||||
"""Type definitions for Text2Vec embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class Text2VecProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Text2Vec provider."""
|
||||
|
||||
model_name: Annotated[str, "shibing624/text2vec-base-chinese"]
|
||||
|
||||
|
||||
class Text2VecProviderSpec(TypedDict):
|
||||
"""Text2Vec provider specification."""
|
||||
|
||||
provider: Required[Literal["text2vec"]]
|
||||
config: Text2VecProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""VoyageAI embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.voyageai.types import (
|
||||
VoyageAIProviderConfig,
|
||||
VoyageAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.voyageai.voyageai_provider import (
|
||||
VoyageAIProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"VoyageAIProvider",
|
||||
"VoyageAIProviderConfig",
|
||||
"VoyageAIProviderSpec",
|
||||
]
|
||||
@@ -1,62 +0,0 @@
|
||||
"""VoyageAI embedding function implementation."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
|
||||
|
||||
|
||||
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"""Embedding function for VoyageAI models."""
|
||||
|
||||
def __init__(self, **kwargs: Unpack[VoyageAIProviderConfig]) -> None:
|
||||
"""Initialize VoyageAI embedding function.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration parameters for VoyageAI.
|
||||
"""
|
||||
try:
|
||||
import voyageai # type: ignore[import-not-found]
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"voyageai is required for voyageai embeddings. "
|
||||
"Install it with: uv add voyageai"
|
||||
) from e
|
||||
self._config = kwargs
|
||||
self._client = voyageai.Client(
|
||||
api_key=kwargs["api_key"],
|
||||
max_retries=kwargs.get("max_retries", 0),
|
||||
timeout=kwargs.get("timeout"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||
return "voyageai"
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""Generate embeddings for input documents.
|
||||
|
||||
Args:
|
||||
input: List of documents to embed.
|
||||
|
||||
Returns:
|
||||
List of embedding vectors.
|
||||
"""
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
result = self._client.embed(
|
||||
texts=input,
|
||||
model=self._config.get("model", "voyage-2"),
|
||||
input_type=self._config.get("input_type"),
|
||||
truncation=self._config.get("truncation", True),
|
||||
output_dtype=self._config.get("output_dtype"),
|
||||
output_dimension=self._config.get("output_dimension"),
|
||||
)
|
||||
|
||||
return cast(Embeddings, result.embeddings)
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Type definitions for VoyageAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class VoyageAIProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for VoyageAI provider."""
|
||||
|
||||
api_key: str
|
||||
model: Annotated[str, "voyage-2"]
|
||||
input_type: str
|
||||
truncation: Annotated[bool, True]
|
||||
output_dtype: str
|
||||
output_dimension: int
|
||||
max_retries: Annotated[int, 0]
|
||||
timeout: float
|
||||
|
||||
|
||||
class VoyageAIProviderSpec(TypedDict):
|
||||
"""VoyageAI provider specification."""
|
||||
|
||||
provider: Required[Literal["voyageai"]]
|
||||
config: VoyageAIProviderConfig
|
||||
@@ -1,55 +0,0 @@
|
||||
"""Voyage AI embeddings provider."""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
|
||||
"""Voyage AI embeddings provider."""
|
||||
|
||||
embedding_callable: type[VoyageAIEmbeddingFunction] = Field(
|
||||
default=VoyageAIEmbeddingFunction,
|
||||
description="Voyage AI embedding function class",
|
||||
)
|
||||
model: str = Field(
|
||||
default="voyage-2",
|
||||
description="Model to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
|
||||
)
|
||||
input_type: str | None = Field(
|
||||
default=None,
|
||||
description="Input type for embeddings",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
|
||||
)
|
||||
truncation: bool = Field(
|
||||
default=True,
|
||||
description="Whether to truncate inputs",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
|
||||
)
|
||||
output_dtype: str | None = Field(
|
||||
default=None,
|
||||
description="Output data type",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
|
||||
)
|
||||
output_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Output dimension",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
|
||||
)
|
||||
max_retries: int = Field(
|
||||
default=0,
|
||||
description="Maximum retries for API calls",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
default=None,
|
||||
description="Timeout for API calls",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
|
||||
)
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Type definitions for the embeddings module."""
|
||||
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderSpec,
|
||||
WatsonXProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
|
||||
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
|
||||
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
ProviderSpec = (
|
||||
AzureProviderSpec
|
||||
| BedrockProviderSpec
|
||||
| CohereProviderSpec
|
||||
| CustomProviderSpec
|
||||
| GenerativeAiProviderSpec
|
||||
| HuggingFaceProviderSpec
|
||||
| InstructorProviderSpec
|
||||
| JinaProviderSpec
|
||||
| OllamaProviderSpec
|
||||
| ONNXProviderSpec
|
||||
| OpenAIProviderSpec
|
||||
| OpenCLIPProviderSpec
|
||||
| RoboflowProviderSpec
|
||||
| SentenceTransformerProviderSpec
|
||||
| Text2VecProviderSpec
|
||||
| VertexAIProviderSpec
|
||||
| VoyageAIProviderSpec
|
||||
| WatsonProviderSpec # Deprecated, use WatsonXProviderSpec
|
||||
| WatsonXProviderSpec
|
||||
)
|
||||
|
||||
AllowedEmbeddingProviders = Literal[
|
||||
"azure",
|
||||
"amazon-bedrock",
|
||||
"cohere",
|
||||
"custom",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"huggingface",
|
||||
"instructor",
|
||||
"jina",
|
||||
"ollama",
|
||||
"onnx",
|
||||
"openai",
|
||||
"openclip",
|
||||
"roboflow",
|
||||
"sentence-transformer",
|
||||
"text2vec",
|
||||
"voyageai",
|
||||
"watsonx",
|
||||
"watson", # for backward compatibility until v1.0.0
|
||||
]
|
||||
|
||||
EmbedderConfig: TypeAlias = (
|
||||
ProviderSpec | BaseEmbeddingsProvider | type[BaseEmbeddingsProvider]
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
"""Qdrant vector database client implementation."""
|
||||
@@ -1 +0,0 @@
|
||||
"""Storage components for RAG infrastructure."""
|
||||
@@ -1,24 +0,0 @@
|
||||
from crewai.utilities.converter import Converter, ConverterError
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from crewai.utilities.file_handler import FileHandler
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.prompts import Prompts
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
|
||||
__all__ = [
|
||||
"I18N",
|
||||
"Converter",
|
||||
"ConverterError",
|
||||
"FileHandler",
|
||||
"InternalInstructor",
|
||||
"LLMContextLengthExceededError",
|
||||
"Logger",
|
||||
"Printer",
|
||||
"Prompts",
|
||||
"RPMController",
|
||||
]
|
||||
@@ -1,32 +0,0 @@
|
||||
from typing import Annotated, Final
|
||||
|
||||
from crewai.utilities.printer import PrinterColor
|
||||
|
||||
TRAINING_DATA_FILE: Final[str] = "training_data.pkl"
|
||||
TRAINED_AGENTS_DATA_FILE: Final[str] = "trained_agents_data.pkl"
|
||||
KNOWLEDGE_DIRECTORY: Final[str] = "knowledge"
|
||||
MAX_FILE_NAME_LENGTH: Final[int] = 255
|
||||
EMITTER_COLOR: Final[PrinterColor] = "bold_blue"
|
||||
|
||||
|
||||
class _NotSpecified:
|
||||
"""Sentinel class to detect when no value has been explicitly provided.
|
||||
|
||||
Notes:
|
||||
- TODO: Consider moving this class and NOT_SPECIFIED to types.py
|
||||
as they are more type-related constructs than business constants.
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "NOT_SPECIFIED"
|
||||
|
||||
|
||||
NOT_SPECIFIED: Final[
|
||||
Annotated[
|
||||
_NotSpecified,
|
||||
"Sentinel value used to detect when no value has been explicitly provided. "
|
||||
"Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` "
|
||||
"allows us to distinguish between 'not passed at all' and 'explicitly passed None' or '[]'.",
|
||||
]
|
||||
] = _NotSpecified()
|
||||
CREWAI_BASE_URL: Final[str] = "https://app.crewai.com"
|
||||
@@ -1,448 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Final, TypedDict, Union, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_output_converter import OutputConverter
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
_JSON_PATTERN: Final[re.Pattern[str]] = re.compile(r"({.*})", re.DOTALL)
|
||||
|
||||
|
||||
class ConverterError(Exception):
|
||||
"""Error raised when Converter fails to parse the input."""
|
||||
|
||||
def __init__(self, message: str, *args: object) -> None:
|
||||
"""Initialize the ConverterError with a message.
|
||||
|
||||
Args:
|
||||
message: The error message.
|
||||
*args: Additional arguments for the base Exception class.
|
||||
"""
|
||||
super().__init__(message, *args)
|
||||
self.message = message
|
||||
|
||||
|
||||
class Converter(OutputConverter):
|
||||
"""Class that converts text into either pydantic or json."""
|
||||
|
||||
def to_pydantic(self, current_attempt: int = 1) -> BaseModel:
|
||||
"""Convert text to pydantic.
|
||||
|
||||
Args:
|
||||
current_attempt: The current attempt number for conversion retries.
|
||||
|
||||
Returns:
|
||||
A Pydantic BaseModel instance.
|
||||
|
||||
Raises:
|
||||
ConverterError: If conversion fails after maximum attempts.
|
||||
"""
|
||||
try:
|
||||
if self.llm.supports_function_calling():
|
||||
result = self._create_instructor().to_pydantic()
|
||||
else:
|
||||
response = self.llm.call(
|
||||
[
|
||||
{"role": "system", "content": self.instructions},
|
||||
{"role": "user", "content": self.text},
|
||||
]
|
||||
)
|
||||
try:
|
||||
# Try to directly validate the response JSON
|
||||
result = self.model.model_validate_json(response)
|
||||
except ValidationError:
|
||||
# If direct validation fails, attempt to extract valid JSON
|
||||
result = handle_partial_json(
|
||||
result=response,
|
||||
model=self.model,
|
||||
is_json_output=False,
|
||||
agent=None,
|
||||
)
|
||||
# Ensure result is a BaseModel instance
|
||||
if not isinstance(result, BaseModel):
|
||||
if isinstance(result, dict):
|
||||
result = self.model.model_validate(result)
|
||||
elif isinstance(result, str):
|
||||
try:
|
||||
parsed = json.loads(result)
|
||||
result = self.model.model_validate(parsed)
|
||||
except Exception as parse_err:
|
||||
raise ConverterError(
|
||||
f"Failed to convert partial JSON result into Pydantic: {parse_err}"
|
||||
) from parse_err
|
||||
else:
|
||||
raise ConverterError(
|
||||
"handle_partial_json returned an unexpected type."
|
||||
) from None
|
||||
return result
|
||||
except ValidationError as e:
|
||||
if current_attempt < self.max_attempts:
|
||||
return self.to_pydantic(current_attempt + 1)
|
||||
raise ConverterError(
|
||||
f"Failed to convert text into a Pydantic model due to validation error: {e}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
if current_attempt < self.max_attempts:
|
||||
return self.to_pydantic(current_attempt + 1)
|
||||
raise ConverterError(
|
||||
f"Failed to convert text into a Pydantic model due to error: {e}"
|
||||
) from e
|
||||
|
||||
def to_json(self, current_attempt: int = 1) -> str | ConverterError | Any: # type: ignore[override]
|
||||
"""Convert text to json.
|
||||
|
||||
Args:
|
||||
current_attempt: The current attempt number for conversion retries.
|
||||
|
||||
Returns:
|
||||
A JSON string or ConverterError if conversion fails.
|
||||
|
||||
Raises:
|
||||
ConverterError: If conversion fails after maximum attempts.
|
||||
|
||||
"""
|
||||
try:
|
||||
if self.llm.supports_function_calling():
|
||||
return self._create_instructor().to_json()
|
||||
return json.dumps(
|
||||
self.llm.call(
|
||||
[
|
||||
{"role": "system", "content": self.instructions},
|
||||
{"role": "user", "content": self.text},
|
||||
]
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
if current_attempt < self.max_attempts:
|
||||
return self.to_json(current_attempt + 1)
|
||||
return ConverterError(f"Failed to convert text into JSON, error: {e}.")
|
||||
|
||||
def _create_instructor(self) -> InternalInstructor:
|
||||
"""Create an instructor."""
|
||||
|
||||
return InternalInstructor(
|
||||
llm=self.llm,
|
||||
model=self.model,
|
||||
content=self.text,
|
||||
)
|
||||
|
||||
|
||||
def convert_to_model(
|
||||
result: str,
|
||||
output_pydantic: type[BaseModel] | None,
|
||||
output_json: type[BaseModel] | None,
|
||||
agent: Agent | None = None,
|
||||
converter_cls: type[Converter] | None = None,
|
||||
) -> dict[str, Any] | BaseModel | str:
|
||||
"""Convert a result string to a Pydantic model or JSON.
|
||||
|
||||
Args:
|
||||
result: The result string to convert.
|
||||
output_pydantic: The Pydantic model class to convert to.
|
||||
output_json: The Pydantic model class to convert to JSON.
|
||||
agent: The agent instance.
|
||||
converter_cls: The converter class to use.
|
||||
|
||||
Returns:
|
||||
The converted result as a dict, BaseModel, or original string.
|
||||
"""
|
||||
model = output_pydantic or output_json
|
||||
if model is None:
|
||||
return result
|
||||
try:
|
||||
escaped_result = json.dumps(json.loads(result, strict=False))
|
||||
return validate_model(
|
||||
result=escaped_result, model=model, is_json_output=bool(output_json)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
return handle_partial_json(
|
||||
result=result,
|
||||
model=model,
|
||||
is_json_output=bool(output_json),
|
||||
agent=agent,
|
||||
converter_cls=converter_cls,
|
||||
)
|
||||
|
||||
except ValidationError:
|
||||
return handle_partial_json(
|
||||
result=result,
|
||||
model=model,
|
||||
is_json_output=bool(output_json),
|
||||
agent=agent,
|
||||
converter_cls=converter_cls,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
Printer().print(
|
||||
content=f"Unexpected error during model conversion: {type(e).__name__}: {e}. Returning original result.",
|
||||
color="red",
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def validate_model(
|
||||
result: str, model: type[BaseModel], is_json_output: bool
|
||||
) -> dict[str, Any] | BaseModel:
|
||||
"""Validate and convert a JSON string to a Pydantic model or dict.
|
||||
|
||||
Args:
|
||||
result: The JSON string to validate and convert.
|
||||
model: The Pydantic model class to convert to.
|
||||
is_json_output: Whether to return a dict (True) or Pydantic model (False).
|
||||
|
||||
Returns:
|
||||
The converted result as a dict or BaseModel.
|
||||
"""
|
||||
exported_result = model.model_validate_json(result)
|
||||
if is_json_output:
|
||||
return exported_result.model_dump()
|
||||
return exported_result
|
||||
|
||||
|
||||
def handle_partial_json(
|
||||
result: str,
|
||||
model: type[BaseModel],
|
||||
is_json_output: bool,
|
||||
agent: Agent | None,
|
||||
converter_cls: type[Converter] | None = None,
|
||||
) -> dict[str, Any] | BaseModel | str:
|
||||
"""Handle partial JSON in a result string and convert to Pydantic model or dict.
|
||||
|
||||
Args:
|
||||
result: The result string to process.
|
||||
model: The Pydantic model class to convert to.
|
||||
is_json_output: Whether to return a dict (True) or Pydantic model (False).
|
||||
agent: The agent instance.
|
||||
converter_cls: The converter class to use.
|
||||
|
||||
Returns:
|
||||
The converted result as a dict, BaseModel, or original string.
|
||||
"""
|
||||
match = _JSON_PATTERN.search(result)
|
||||
if match:
|
||||
try:
|
||||
exported_result = model.model_validate_json(match.group())
|
||||
if is_json_output:
|
||||
return exported_result.model_dump()
|
||||
return exported_result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except ValidationError:
|
||||
pass
|
||||
except Exception as e:
|
||||
Printer().print(
|
||||
content=f"Unexpected error during partial JSON handling: {type(e).__name__}: {e}. Attempting alternative conversion method.",
|
||||
color="red",
|
||||
)
|
||||
|
||||
return convert_with_instructions(
|
||||
result=result,
|
||||
model=model,
|
||||
is_json_output=is_json_output,
|
||||
agent=agent,
|
||||
converter_cls=converter_cls,
|
||||
)
|
||||
|
||||
|
||||
def convert_with_instructions(
|
||||
result: str,
|
||||
model: type[BaseModel],
|
||||
is_json_output: bool,
|
||||
agent: Agent | None,
|
||||
converter_cls: type[Converter] | None = None,
|
||||
) -> dict | BaseModel | str:
|
||||
"""Convert a result string to a Pydantic model or JSON using instructions.
|
||||
|
||||
Args:
|
||||
result: The result string to convert.
|
||||
model: The Pydantic model class to convert to.
|
||||
is_json_output: Whether to return a dict (True) or Pydantic model (False).
|
||||
agent: The agent instance.
|
||||
converter_cls: The converter class to use.
|
||||
|
||||
Returns:
|
||||
The converted result as a dict, BaseModel, or original string.
|
||||
|
||||
Raises:
|
||||
TypeError: If neither agent nor converter_cls is provided.
|
||||
|
||||
Notes:
|
||||
- TODO: Fix llm typing issues, return llm should not be able to be str or None.
|
||||
"""
|
||||
if agent is None:
|
||||
raise TypeError("Agent must be provided if converter_cls is not specified.")
|
||||
llm = agent.function_calling_llm or agent.llm
|
||||
instructions = get_conversion_instructions(model=model, llm=llm)
|
||||
converter = create_converter(
|
||||
agent=agent,
|
||||
converter_cls=converter_cls,
|
||||
llm=llm,
|
||||
text=result,
|
||||
model=model,
|
||||
instructions=instructions,
|
||||
)
|
||||
exported_result = (
|
||||
converter.to_pydantic() if not is_json_output else converter.to_json()
|
||||
)
|
||||
|
||||
if isinstance(exported_result, ConverterError):
|
||||
Printer().print(
|
||||
content=f"{exported_result.message} Using raw output instead.",
|
||||
color="red",
|
||||
)
|
||||
return result
|
||||
|
||||
return exported_result
|
||||
|
||||
|
||||
def get_conversion_instructions(
|
||||
model: type[BaseModel], llm: BaseLLM | LLM | str
|
||||
) -> str:
|
||||
"""Generate conversion instructions based on the model and LLM capabilities.
|
||||
|
||||
Args:
|
||||
model: A Pydantic model class.
|
||||
llm: The language model instance.
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
instructions = "Please convert the following text into valid JSON."
|
||||
if (
|
||||
llm
|
||||
and not isinstance(llm, str)
|
||||
and hasattr(llm, "supports_function_calling")
|
||||
and llm.supports_function_calling()
|
||||
):
|
||||
model_schema = PydanticSchemaParser(model=model).get_schema()
|
||||
instructions += (
|
||||
f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
|
||||
f"The JSON must follow this schema exactly:\n```json\n{model_schema}\n```"
|
||||
)
|
||||
else:
|
||||
model_description = generate_model_description(model)
|
||||
instructions += (
|
||||
f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
|
||||
f"The JSON must follow this format exactly:\n{model_description}"
|
||||
)
|
||||
return instructions
|
||||
|
||||
|
||||
class CreateConverterKwargs(TypedDict, total=False):
|
||||
"""Keyword arguments for creating a converter.
|
||||
|
||||
Attributes:
|
||||
llm: The language model instance.
|
||||
text: The text to convert.
|
||||
model: The Pydantic model class.
|
||||
instructions: The conversion instructions.
|
||||
"""
|
||||
|
||||
llm: BaseLLM | LLM | str
|
||||
text: str
|
||||
model: type[BaseModel]
|
||||
instructions: str
|
||||
|
||||
|
||||
def create_converter(
    agent: Agent | None = None,
    converter_cls: type[Converter] | None = None,
    *args: Any,
    **kwargs: Unpack[CreateConverterKwargs],
) -> Converter:
    """Create a converter instance based on the agent or provided class.

    An explicit ``converter_cls`` takes precedence over the agent's own
    converter factory.

    Args:
        agent: The agent instance.
        converter_cls: The converter class to instantiate.
        *args: The positional arguments to pass to the converter.
        **kwargs: The keyword arguments to pass to the converter.

    Returns:
        An instance of the specified converter class.

    Raises:
        ValueError: If neither agent nor converter_cls is provided.
        AttributeError: If the agent does not have a 'get_output_converter' method.
        Exception: If no converter instance is created.

    """
    if converter_cls:
        converter = converter_cls(*args, **kwargs)
    elif agent:
        # Guard clause: fail loudly when the agent cannot build converters.
        if not hasattr(agent, "get_output_converter"):
            raise AttributeError("Agent does not have a 'get_output_converter' method")
        converter = agent.get_output_converter(*args, **kwargs)
    else:
        raise ValueError("Either agent or converter_cls must be provided")

    # A factory may legally return None/falsy; treat that as a hard failure.
    if not converter:
        raise Exception("No output converter found or set.")

    return converter
|
||||
|
||||
|
||||
def generate_model_description(model: type[BaseModel]) -> str:
|
||||
"""Generate a string description of a Pydantic model's fields and their types.
|
||||
|
||||
This function takes a Pydantic model class and returns a string that describes
|
||||
the model's fields and their respective types. The description includes handling
|
||||
of complex types such as `Optional`, `List`, and `Dict`, as well as nested Pydantic
|
||||
models.
|
||||
|
||||
Args:
|
||||
model: A Pydantic model class.
|
||||
|
||||
Returns:
|
||||
A string representation of the model's fields and types.
|
||||
"""
|
||||
|
||||
def describe_field(field_type: Any) -> str:
|
||||
"""Recursively describe a field's type.
|
||||
|
||||
Args:
|
||||
field_type: The type of the field to describe.
|
||||
|
||||
Returns:
|
||||
A string representation of the field's type.
|
||||
"""
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
if origin is Union or (origin is None and len(args) > 0):
|
||||
# Handle both Union and the new '|' syntax
|
||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
||||
if len(non_none_args) == 1:
|
||||
return f"Optional[{describe_field(non_none_args[0])}]"
|
||||
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
|
||||
if origin is list:
|
||||
return f"List[{describe_field(args[0])}]"
|
||||
if origin is dict:
|
||||
key_type = describe_field(args[0])
|
||||
value_type = describe_field(args[1])
|
||||
return f"Dict[{key_type}, {value_type}]"
|
||||
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
return generate_model_description(field_type)
|
||||
if hasattr(field_type, "__name__"):
|
||||
return field_type.__name__
|
||||
return str(field_type)
|
||||
|
||||
fields = model.model_fields
|
||||
field_descriptions = [
|
||||
f'"{name}": {describe_field(field.annotation)}'
|
||||
for name, field in fields.items()
|
||||
]
|
||||
return "{\n " + ",\n ".join(field_descriptions) + "\n}"
|
||||
@@ -1 +0,0 @@
|
||||
"""Crew-specific utilities."""
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Models for crew-related data structures."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CrewContext(BaseModel):
    """Model representing crew context information.

    Lightweight, serializable identification record for a crew; both fields
    default to ``None``, so an empty context is valid.

    Attributes:
        id: Unique identifier for the crew.
        key: Optional crew key/name for identification.
    """

    id: str | None = Field(default=None, description="Unique identifier for the crew")
    key: str | None = Field(
        default=None, description="Optional crew key/name for identification"
    )
|
||||
@@ -1,58 +0,0 @@
|
||||
from typing import Final
|
||||
|
||||
# Substrings (matched case-insensitively by LLMContextLengthExceededError)
# that identify provider error messages signalling the model's context
# window was exceeded.
CONTEXT_LIMIT_ERRORS: Final[list[str]] = [
    "expected a string with maximum length",
    "maximum context length",
    "context length exceeded",
    "context_length_exceeded",
    "context window full",
    "too many tokens",
    "input is too long",
    "exceeds token limit",
]
|
||||
|
||||
|
||||
class LLMContextLengthExceededError(Exception):
    """Raised when a language model call fails because its context length was exceeded.

    Attributes:
        original_error_message: The original error message from the LLM.
    """

    def __init__(self, error_message: str) -> None:
        """Store the original message and raise with a user-friendly version.

        Args:
            error_message: The original error message from the LLM.
        """
        self.original_error_message = error_message
        super().__init__(self._get_error_message(error_message))

    @staticmethod
    def _is_context_limit_error(error_message: str) -> bool:
        """Check whether an error message matches a known context-limit phrase.

        Args:
            error_message: The error message to check.

        Returns:
            True if the error message indicates a context length limit error,
            False otherwise.
        """
        # Lowercase the haystack once instead of per-phrase.
        haystack = error_message.lower()
        return any(phrase.lower() in haystack for phrase in CONTEXT_LIMIT_ERRORS)

    @staticmethod
    def _get_error_message(error_message: str) -> str:
        """Generate a user-friendly error message based on the original one.

        Args:
            error_message: The original error message from the LLM.

        Returns:
            A user-friendly error message.
        """
        guidance = "Consider using a smaller input or implementing a text splitting strategy."
        return f"LLM context length exceeded. Original error: {error_message}\n{guidance}"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user