Compare commits

...

36 Commits

Author SHA1 Message Date
Devin AI
411285f5ef fix: TypedDict compatibility for Python 3.11 and remove unused imports
- Use typing_extensions.TypedDict instead of typing.TypedDict for Python < 3.12 compatibility
- Remove unused pytest import from test_config.py
- Remove unused sys import from test_factory.py
- Fixes Pydantic error: 'Please use typing_extensions.TypedDict instead of typing.TypedDict on Python < 3.12'

Co-Authored-By: João <joao@crewai.com>
2025-08-27 01:23:10 +00:00
Devin AI
dce26e8276 fix: Address CI failures - type annotations, lint, security
- Fix TypeAlias annotation in elasticsearch/types.py using TYPE_CHECKING
- Add 'elasticsearch' to _MissingProvider Literal type in base.py
- Remove unused variable in test_client.py
- Add usedforsecurity=False to MD5 hash in config.py for security check

Co-Authored-By: João <joao@crewai.com>
2025-08-27 01:17:40 +00:00
Devin AI
e3a575920c feat: Add comprehensive Elasticsearch support to crewai.rag
- Implement ElasticsearchClient with full sync/async operations
- Add ElasticsearchConfig with connection and embedding options
- Create factory pattern following ChromaDB/Qdrant conventions
- Add comprehensive test suite with 26 passing tests (100% coverage)
- Support both sync and async Elasticsearch operations
- Include proper error handling and edge case coverage
- Update type system and factory to support Elasticsearch provider
- Follow existing RAG patterns for consistency

Resolves #3404

Co-Authored-By: João <joao@crewai.com>
2025-08-27 01:07:57 +00:00
Lucas Gomide
88d2968fd5 chore: add deprecation notices to Task.max_retries (#3379)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2025-08-26 17:24:58 -04:00
Lorenze Jay
7addda9398 Lorenze/better tracing events (#3382)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
* feat: implement tool usage limit exception handling

- Introduced `ToolUsageLimitExceeded` exception to manage maximum usage limits for tools.
- Enhanced `CrewStructuredTool` to check and raise this exception when the usage limit is reached.
- Updated `_run` and `_execute` methods to include usage limit checks and handle exceptions appropriately, improving reliability and user feedback.

* feat: enhance PlusAPI and ToolUsage with task metadata

- Removed the `send_trace_batch` method from PlusAPI to streamline the API.
- Added timeout parameters to trace event methods in PlusAPI for improved reliability.
- Updated ToolUsage to include task metadata (task name and ID) in event emissions, enhancing traceability and context during tool usage.
- Refactored event handling in LLM and ToolUsage events to ensure task information is consistently captured.

* feat: enhance memory and event handling with task and agent metadata

- Added task and agent metadata to various memory and event classes, improving traceability and context during memory operations.
- Updated the `ContextualMemory` and `Memory` classes to associate tasks and agents, allowing for better context management.
- Enhanced event emissions in `LLM`, `ToolUsage`, and memory events to include task and agent information, facilitating improved debugging and monitoring.
- Refactored event handling to ensure consistent capture of task and agent details across the system.

* drop

* refactor: clean up unused imports in memory and event modules

- Removed unused TYPE_CHECKING imports from long_term_memory.py to streamline the code.
- Eliminated unnecessary import from memory_events.py, enhancing clarity and maintainability.

* fix memory tests

* fix task_completed payload

* fix: remove unused test agent variable in external memory tests

* refactor: remove unused agent parameter from Memory class save method

- Eliminated the agent parameter from the save method in the Memory class to streamline the code and improve clarity.
- Updated the TraceBatchManager class by moving initialization of attributes into the constructor for better organization and readability.

* refactor: enhance ExecutionState and ReasoningEvent classes with optional task and agent identifiers

- Added optional `current_agent_id` and `current_task_id` attributes to the `ExecutionState` class for better tracking of agent and task states.
- Updated the `from_task` attribute in the `ReasoningEvent` class to use `Optional[Any]` instead of a specific type, improving flexibility in event handling.

* refactor: update ExecutionState class by removing unused agent and task identifiers

- Removed the `current_agent_id` and `current_task_id` attributes from the `ExecutionState` class to simplify the code and enhance clarity.
- Adjusted the import statements to include `Optional` for better type handling.

* refactor: streamline LLM event handling in LiteAgent

- Removed unused LLM event emissions (LLMCallStartedEvent, LLMCallCompletedEvent, LLMCallFailedEvent) from the LiteAgent class to simplify the code and improve performance.
- Adjusted the flow of LLM response handling by eliminating unnecessary event bus interactions, enhancing clarity and maintainability.

* flow ownership and not emitting events when a crew is done

* refactor: remove unused agent parameter from ShortTermMemory save method

- Eliminated the agent parameter from the save method in the ShortTermMemory class to streamline the code and improve clarity.
- This change enhances the maintainability of the memory management system by reducing unnecessary complexity.

* runtype check fix

* fixing tests

* fix lints

* fix: update event assertions in test_llm_emits_event_with_lite_agent

- Adjusted the expected counts for completed and started events in the test to reflect the correct behavior of the LiteAgent.
- Updated assertions for agent roles and IDs to match the expected values after recent changes in event handling.

* fix: update task name assertions in event tests

- Modified assertions in `test_stream_llm_emits_event_with_task_and_agent_info` and `test_llm_emits_event_with_task_and_agent_info` to use `task.description` as a fallback for `task.name`. This ensures that the tests correctly validate the task name even when it is not explicitly set.

* fix: update test assertions for output values and improve readability

- Updated assertions in `test_output_json_dict_hierarchical` to reflect the correct expected score value.
- Enhanced readability of assertions in `test_output_pydantic_to_another_task` and `test_key` by formatting the error messages for clarity.
- These changes ensure that the tests accurately validate the expected outputs and improve overall code quality.

* test fixes

* fix crew_test

* added another fixture

* fix: ensure agent and task assignments in contextual memory are conditional

- Updated the ContextualMemory class to check for the existence of short-term, long-term, external, and extended memory before assigning agent and task attributes. This prevents potential attribute errors when memory types are not initialized.
2025-08-26 09:09:46 -07:00
Greyson LaLonde
4b4a119a9f refactor: simplify rag client initialization (#3401)
* Simplified Qdrant and ChromaDB client initialization
* Refactored factory structure and updated tests accordingly
2025-08-26 08:54:51 -04:00
Greyson LaLonde
869bb115c8 Qdrant RAG Provider Support (#3400)
* Added Qdrant provider support with factory, config, and protocols
* Improved default embeddings and type definitions
* Fixed ChromaDB factory embedding assignment
2025-08-26 08:44:02 -04:00
Greyson LaLonde
7ac482c7c9 feat: rag configuration with optional dependency support (#3394)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
### RAG Config System

* Added ChromaDB client creation via config with sensible defaults
* Introduced optional imports and shared RAG config utilities/schema
* Enabled embedding function support with ChromaDB provider integration
* Refactored configs for immutability and stronger type safety
* Removed unused code and expanded test coverage
2025-08-26 00:00:22 -04:00
Greyson LaLonde
2e4bd3f49d feat: qdrant generic client (#3377)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
### Qdrant Client

* Add core client with collection, search, and document APIs (sync + async)
* Refactor utilities, types, and vector params (default 384-dim)
* Improve error handling with `ClientMethodMismatchError`
* Add score normalization, async embeddings, and optional `qdrant-client` dep
* Expand tests and type safety throughout
2025-08-25 16:02:25 -04:00
Greyson LaLonde
c02997d956 Add import utilities for optional dependencies (#3389)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2025-08-24 22:57:44 -04:00
Heitor Carvalho
f96b779df5 feat: reset tokens on crewai config reset (#3365)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2025-08-22 16:16:42 -04:00
Greyson LaLonde
842bed4e9c feat: chromadb generic client (#3374)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Add ChromaDB client implementation with async support

- Implement core collection operations (create, get_or_create, delete)
- Add search functionality with cosine similarity scoring
- Include both sync and async method variants
- Add type safety with NamedTuples and TypeGuards
- Extract utility functions to separate modules
- Default to cosine distance metric for text similarity
- Add comprehensive test coverage

TODO:
- l2, ip score calculations are not settled on
2025-08-21 18:18:46 -04:00
Lucas Gomide
1217935b31 feat: add docs about Automation triggers (#3375)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2025-08-20 22:02:47 -04:00
Greyson LaLonde
641c156c17 fix: address flaky tests (#3363)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
fix: resolve flaky tests and race conditions in test suite

- Fix telemetry/event tests by patching class methods instead of instances
- Use unique temp files/directories to prevent CI race conditions
- Reset singleton state between tests
- Mock embedchain.Client.setup() to prevent JSON corruption
- Rename test files to test_*.py convention
- Move agent tests to tests/agents directory
- Fix repeated tool usage detection
- Remove database-dependent tools causing initialization errors
2025-08-20 13:34:09 -04:00
Tony Kipkemboi
7fdf9f9290 docs: fix API Reference OpenAPI sources and redirects (#3368)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
* docs: fix API Reference OpenAPI sources and redirects; clarify training data usage; add Mermaid diagram; correct CLI usage and notes

* docs(mintlify): use explicit openapi {source, directory} with absolute paths to fix branch deployment routing

* docs(mintlify): add explicit endpoint MDX pages and include in nav; keep OpenAPI auto-gen as fallback

* docs(mintlify): remove OpenAPI Endpoints groups; add localized MDX endpoint pages for pt-BR and ko
2025-08-20 11:55:35 -04:00
Greyson LaLonde
c0d2bf4c12 fix: flow listener resumability for HITL and cyclic flows (#3322)
* fix: flow listener resumability for HITL and cyclic flows

- Add resumption context flag to distinguish HITL resumption from cyclic execution
- Skip method re-execution only during HITL resumption, not for cyclic flows
- Ensure cyclic flows like test_cyclic_flow continue to work correctly

* fix: prevent duplicate execution of conditional start methods in flows

* fix: resolve type error in flow.py line 1040 assignment
2025-08-20 10:06:18 -04:00
Greyson LaLonde
ed187b495b feat: centralize embedding types and create base client (#3246)
feat: add RAG system foundation with generic vector store support

- Add BaseClient protocol for vector stores
- Move BaseRAGStorage to rag/core
- Centralize embedding types in embeddings/types.py
- Remove unused storage models
2025-08-20 09:35:27 -04:00
Wajeeh ul Hassan
2773996b49 fix: revert pin openai<1.100.0 to openai>=1.13.3 (#3364) 2025-08-20 09:16:26 -04:00
Damian Silbergleith
95923b78c6 feat: display task name in verbose output (#3308)
* feat: display task name in verbose output

- Modified event_listener.py to pass task names to the formatter
- Updated console_formatter.py to display task names when available
- Maintains backward compatibility by showing UUID for tasks without names
- Makes verbose output more informative and readable

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: remove unnecessary f-string prefixes in console formatter

Remove extraneous f prefixes from string literals without placeholders
in console_formatter.py to resolve ruff F541 linting errors.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-08-20 08:43:05 -04:00
Lucas Gomide
7065ad4336 feat: adding additional parameter to Flow' start methods (#3356)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
* feat: adding additional parameter to Flow' start methods

When the `crewai_trigger_payload` parameter exists in the input Flow, we will add it in the start Flow methods as parameter

* fix: support crewai_trigger_payload in async Flow start methods
2025-08-19 17:32:19 -04:00
Lorenze Jay
d6254918fd Lorenze/max retry defaults tools (#3362)
* feat: enhance BaseTool and CrewStructuredTool with usage tracking

This commit introduces a mechanism to track the usage count of tools within the CrewAI framework. The `BaseTool` class now includes a `_increment_usage_count` method that updates the current usage count, which is also reflected in the associated `CrewStructuredTool`. Additionally, a new test has been added to ensure that the maximum usage count is respected when invoking tools, enhancing the overall reliability and functionality of the tool system.

* feat: add max usage count feature to tools documentation

This commit introduces a new section in the tools overview documentation that explains the maximum usage count feature for tools within the CrewAI framework. Users can now set a limit on how many times a tool can be used, enhancing control over tool usage. An example of implementing the `FileReadTool` with a maximum usage count is also provided, improving the clarity and usability of the documentation.

* undo field string
2025-08-19 10:44:55 -07:00
Heitor Carvalho
95e3d6db7a fix: add 'tool' section migration when running crewai update (#3341)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
2025-08-19 08:11:30 -04:00
Lorenze Jay
d7f8002baa chore: update crewAI version to 0.165.1 and tools dependency in templates (#3359) (#3359)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2025-08-19 00:06:31 -03:00
Lorenze Jay
d743e12a06 refactor: streamline tracing condition checks and clean up deprecated warnings (#3358)
This commit simplifies the conditions for enabling tracing in both the Crew and Flow classes by removing the redundant call to `on_first_execution_tracing_confirmation()`. Additionally, it removes deprecated warning filters related to Pydantic in the KnowledgeStorage and RAGStorage classes, improving code clarity and maintainability.
2025-08-18 19:56:00 -07:00
Lorenze Jay
6068fe941f chore: update crewAI version to 0.165.0 and tools dependency to 0.62.1 (#3357) 2025-08-18 18:25:59 -07:00
Lucas Gomide
2a0cefc98b feat: pin openai<1.100.0 due ResponseTextConfigParam import issue (#3355)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
2025-08-18 18:31:18 -04:00
Lucas Gomide
a4f65e4870 chore: renaming inject_trigger_input to allow_crewai_trigger_context (#3353)
* chore: renaming inject_trigger_input to allow_crewai_trigger_context

* test: add missing cassetes
2025-08-18 17:57:21 -04:00
Lorenze Jay
a1b3edd79c Refactor tracing logic to consolidate conditions for enabling tracing… (#3347)
* Refactor tracing logic to consolidate conditions for enabling tracing in Crew class and update TraceBatchManager to handle ephemeral batches more effectively. Added tests for trace listener handling of both ephemeral and authenticated user batches.

* drop print

* linted

* refactor: streamline ephemeral handling in TraceBatchManager

This commit removes the ephemeral parameter from the _send_events_to_backend and _finalize_backend_batch methods, replacing it with internal logic that checks the current batch's ephemeral status. This change simplifies the method signatures and enhances the clarity of the code by directly using the is_current_batch_ephemeral attribute for conditional logic.
2025-08-18 14:16:51 -07:00
Lucas Gomide
80b3d9689a Auto inject crewai_trigger_payload (#3351)
* feat: add props to inject trigger payload

* feat: auto-inject trigger_input in the first crew task
2025-08-18 16:36:08 -04:00
Vini Brasil
ec03a53121 Add example to Tool Repository docs (#3352) 2025-08-18 13:19:35 -07:00
Vini Brasil
2fdf3f3a6a Move Chroma lockfile to db/ (#3342)
This commit fixes an issue where using Chroma would spam lockfiles over
the root path of the crew.
2025-08-18 11:00:50 -07:00
Greyson LaLonde
1d3d7ebf5e fix: convert XMLSearchTool config values to strings for configparser compatibility (#3344) 2025-08-18 13:23:58 -04:00
Gabe Milani
2c2196f415 fix: flaky test with PytestUnraisableExceptionWarning (#3346) 2025-08-18 14:07:51 -03:00
Gabe Milani
c9f30b175c chore: ignore deprecation warning from chromadb (#3328)
* chore: ignore deprecation warning from chromadb

* adding TODO: in the comment
2025-08-18 13:24:11 -03:00
Greyson LaLonde
a17b93a7f8 Mock telemetry in pytest tests (#3340)
* Add telemetry mocking for pytest tests

- Mock telemetry by default for all tests except telemetry-specific tests
- Add @pytest.mark.telemetry marker for real telemetry tests
- Reduce test overhead and improve isolation

* Fix telemetry test isolation

- Properly isolate telemetry tests from mocking environment
- Preserve API keys and other necessary environment variables
- Ensure telemetry tests can run with real telemetry instances
2025-08-18 11:55:30 -04:00
namho kim
0d3e462791 fix: Revised Korean translation and sentence structure improvement (#3337)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
2025-08-18 10:46:13 -04:00
181 changed files with 16949 additions and 1283 deletions

View File

@@ -320,6 +320,7 @@
"en/enterprise/guides/update-crew",
"en/enterprise/guides/enable-crew-studio",
"en/enterprise/guides/azure-openai-setup",
"en/enterprise/guides/automation-triggers",
"en/enterprise/guides/hubspot-trigger",
"en/enterprise/guides/react-component-export",
"en/enterprise/guides/salesforce-trigger",
@@ -341,11 +342,12 @@
"groups": [
{
"group": "Getting Started",
"pages": ["en/api-reference/introduction"]
},
{
"group": "Endpoints",
"openapi": "https://raw.githubusercontent.com/crewAIInc/crewAI/main/docs/enterprise-api.en.yaml"
"pages": [
"en/api-reference/introduction",
"en/api-reference/inputs",
"en/api-reference/kickoff",
"en/api-reference/status"
]
}
]
},
@@ -657,6 +659,7 @@
"pt-BR/enterprise/guides/update-crew",
"pt-BR/enterprise/guides/enable-crew-studio",
"pt-BR/enterprise/guides/azure-openai-setup",
"pt-BR/enterprise/guides/automation-triggers",
"pt-BR/enterprise/guides/hubspot-trigger",
"pt-BR/enterprise/guides/react-component-export",
"pt-BR/enterprise/guides/salesforce-trigger",
@@ -680,11 +683,12 @@
"groups": [
{
"group": "Começando",
"pages": ["pt-BR/api-reference/introduction"]
},
{
"group": "Endpoints",
"openapi": "https://raw.githubusercontent.com/crewAIInc/crewAI/main/docs/enterprise-api.pt-BR.yaml"
"pages": [
"pt-BR/api-reference/introduction",
"pt-BR/api-reference/inputs",
"pt-BR/api-reference/kickoff",
"pt-BR/api-reference/status"
]
}
]
},
@@ -709,7 +713,7 @@
"icon": "globe"
},
{
"anchor": "법정",
"anchor": "포럼",
"href": "https://community.crewai.com",
"icon": "discourse"
},
@@ -719,7 +723,7 @@
"icon": "robot"
},
{
"anchor": "출시",
"anchor": "릴리스",
"href": "https://github.com/crewAIInc/crewAI/releases",
"icon": "tag"
}
@@ -734,22 +738,22 @@
"pages": ["ko/introduction", "ko/installation", "ko/quickstart"]
},
{
"group": "안내서",
"group": "가이드",
"pages": [
{
"group": "전략",
"pages": ["ko/guides/concepts/evaluating-use-cases"]
},
{
"group": "Agents",
"group": "에이전트 (Agents)",
"pages": ["ko/guides/agents/crafting-effective-agents"]
},
{
"group": "Crews",
"group": "크루 (Crews)",
"pages": ["ko/guides/crews/first-crew"]
},
{
"group": "Flows",
"group": "플로우 (Flows)",
"pages": [
"ko/guides/flows/first-flow",
"ko/guides/flows/mastering-flow-state"
@@ -797,7 +801,7 @@
]
},
{
"group": "도구",
"group": "도구 (Tools)",
"pages": [
"ko/tools/overview",
{
@@ -887,7 +891,7 @@
]
},
{
"group": "클라우드 & 저장",
"group": "클라우드 & 스토리지",
"pages": [
"ko/tools/cloud-storage/overview",
"ko/tools/cloud-storage/s3readertool",
@@ -909,7 +913,7 @@
]
},
{
"group": "오브저버빌리티",
"group": "Observability",
"pages": [
"ko/observability/overview",
"ko/observability/arize-phoenix",
@@ -927,7 +931,7 @@
]
},
{
"group": "익히다",
"group": "학습",
"pages": [
"ko/learn/overview",
"ko/learn/llm-selection-guide",
@@ -951,13 +955,13 @@
]
},
{
"group": "원격측정",
"group": "Telemetry",
"pages": ["ko/telemetry"]
}
]
},
{
"tab": "기업",
"tab": "엔터프라이즈",
"groups": [
{
"group": "시작 안내",
@@ -997,7 +1001,7 @@
]
},
{
"group": "사용 안내서",
"group": "How-To Guides",
"pages": [
"ko/enterprise/guides/build-crew",
"ko/enterprise/guides/deploy-crew",
@@ -1005,6 +1009,7 @@
"ko/enterprise/guides/update-crew",
"ko/enterprise/guides/enable-crew-studio",
"ko/enterprise/guides/azure-openai-setup",
"ko/enterprise/guides/automation-triggers",
"ko/enterprise/guides/hubspot-trigger",
"ko/enterprise/guides/react-component-export",
"ko/enterprise/guides/salesforce-trigger",
@@ -1026,11 +1031,12 @@
"groups": [
{
"group": "시작 안내",
"pages": ["ko/api-reference/introduction"]
},
{
"group": "Endpoints",
"openapi": "https://raw.githubusercontent.com/crewAIInc/crewAI/main/docs/enterprise-api.ko.yaml"
"pages": [
"ko/api-reference/introduction",
"ko/api-reference/inputs",
"ko/api-reference/kickoff",
"ko/api-reference/status"
]
}
]
},
@@ -1081,6 +1087,10 @@
"indexing": "all"
},
"redirects": [
{
"source": "/api-reference",
"destination": "/en/api-reference/introduction"
},
{
"source": "/introduction",
"destination": "/en/introduction"
@@ -1133,6 +1143,18 @@
"source": "/api-reference/:path*",
"destination": "/en/api-reference/:path*"
},
{
"source": "/en/api-reference",
"destination": "/en/api-reference/introduction"
},
{
"source": "/pt-BR/api-reference",
"destination": "/pt-BR/api-reference/introduction"
},
{
"source": "/ko/api-reference",
"destination": "/ko/api-reference/introduction"
},
{
"source": "/examples/:path*",
"destination": "/en/examples/:path*"

View File

@@ -0,0 +1,7 @@
---
title: "GET /inputs"
description: "Get required inputs for your crew"
openapi: "/enterprise-api.en.yaml GET /inputs"
---

View File

@@ -0,0 +1,7 @@
---
title: "POST /kickoff"
description: "Start a crew execution"
openapi: "/enterprise-api.en.yaml POST /kickoff"
---

View File

@@ -0,0 +1,7 @@
---
title: "GET /status/{kickoff_id}"
description: "Get execution status"
openapi: "/enterprise-api.en.yaml GET /status/{kickoff_id}"
---

View File

@@ -59,6 +59,7 @@ crew = Crew(
| **Output Pydantic** _(optional)_ | `output_pydantic` | `Optional[Type[BaseModel]]` | A Pydantic model for task output. |
| **Callback** _(optional)_ | `callback` | `Optional[Any]` | Function/object to be executed after task completion. |
| **Guardrail** _(optional)_ | `guardrail` | `Optional[Callable]` | Function to validate task output before proceeding to next task. |
| **Guardrail Max Retries** _(optional)_ | `guardrail_max_retries` | `Optional[int]` | Maximum number of retries when guardrail validation fails. Defaults to 3. |
## Creating Tasks
@@ -452,7 +453,7 @@ task = Task(
expected_output="A valid JSON object",
agent=analyst,
guardrail=validate_json_output,
max_retries=3 # Limit retry attempts
guardrail_max_retries=3 # Limit retry attempts
)
```

View File

@@ -21,13 +21,17 @@ To use the training feature, follow these steps:
3. Run the following command:
```shell
crewai train -n <n_iterations> <filename> (optional)
crewai train -n <n_iterations> -f <filename.pkl>
```
<Tip>
Replace `<n_iterations>` with the desired number of training iterations and `<filename>` with the appropriate filename ending with `.pkl`.
</Tip>
### Training Your Crew Programmatically
<Note>
If you omit `-f`, the output defaults to `trained_agents_data.pkl` in the current working directory. You can pass an absolute path to control where the file is written.
</Note>
### Training your Crew programmatically
To train your crew programmatically, use the following steps:
@@ -51,19 +55,65 @@ except Exception as e:
raise Exception(f"An error occurred while training the crew: {e}")
```
### Key Points to Note
## How trained data is used by agents
- **Positive Integer Requirement:** Ensure that the number of iterations (`n_iterations`) is a positive integer. The code will raise a `ValueError` if this condition is not met.
- **Filename Requirement:** Ensure that the filename ends with `.pkl`. The code will raise a `ValueError` if this condition is not met.
- **Error Handling:** The code handles subprocess errors and unexpected exceptions, providing error messages to the user.
CrewAI uses the training artifacts in two ways: during training to incorporate your human feedback, and after training to guide agents with consolidated suggestions.
It is important to note that the training process may take some time, depending on the complexity of your agents and will also require your feedback on each iteration.
### Training data flow
Once the training is complete, your agents will be equipped with enhanced capabilities and knowledge, ready to tackle complex tasks and provide more consistent and valuable insights.
```mermaid
flowchart TD
A["Start training<br/>CLI: crewai train -n -f<br/>or Python: crew.train(...)"] --> B["Setup training mode<br/>- task.human_input = true<br/>- disable delegation<br/>- init training_data.pkl + trained file"]
Remember to regularly update and retrain your agents to ensure they stay up-to-date with the latest information and advancements in the field.
subgraph "Iterations"
direction LR
C["Iteration i<br/>initial_output"] --> D["User human_feedback"]
D --> E["improved_output"]
E --> F["Append to training_data.pkl<br/>by agent_id and iteration"]
end
Happy training with CrewAI! 🚀
B --> C
F --> G{"More iterations?"}
G -- "Yes" --> C
G -- "No" --> H["Evaluate per agent<br/>aggregate iterations"]
H --> I["Consolidate<br/>suggestions[] + quality + final_summary"]
I --> J["Save by agent role to trained file<br/>(default: trained_agents_data.pkl)"]
J --> K["Normal (non-training) runs"]
K --> L["Auto-load suggestions<br/>from trained_agents_data.pkl"]
L --> M["Append to prompt<br/>for consistent improvements"]
```
### During training runs
- On each iteration, the system records for every agent:
- `initial_output`: the agents first answer
- `human_feedback`: your inline feedback when prompted
- `improved_output`: the agents follow-up answer after feedback
- This data is stored in a working file named `training_data.pkl` keyed by the agents internal ID and iteration.
- While training is active, the agent automatically appends your prior human feedback to its prompt to enforce those instructions on subsequent attempts within the training session.
Training is interactive: tasks set `human_input = true`, so running in a non-interactive environment will block on user input.
### After training completes
- When `train(...)` finishes, CrewAI evaluates the collected training data per agent and produces a consolidated result containing:
- `suggestions`: clear, actionable instructions distilled from your feedback and the difference between initial/improved outputs
- `quality`: a 010 score capturing improvement
- `final_summary`: a step-by-step set of action items for future tasks
- These consolidated results are saved to the filename you pass to `train(...)` (default via CLI is `trained_agents_data.pkl`). Entries are keyed by the agents `role` so they can be applied across sessions.
- During normal (non-training) execution, each agent automatically loads its consolidated `suggestions` and appends them to the task prompt as mandatory instructions. This gives you consistent improvements without changing your agent definitions.
### File summary
- `training_data.pkl` (ephemeral, per-session):
- Structure: `agent_id -> { iteration_number: { initial_output, human_feedback, improved_output } }`
- Purpose: capture raw data and human feedback during training
- Location: saved in the current working directory (CWD)
- `trained_agents_data.pkl` (or your custom filename):
- Structure: `agent_role -> { suggestions: string[], quality: number, final_summary: string }`
- Purpose: persist consolidated guidance for future runs
- Location: written to the CWD by default; use `-f` to set a custom (including absolute) path
## Small Language Model Considerations
@@ -129,3 +179,18 @@ Happy training with CrewAI! 🚀
</Warning>
</Tab>
</Tabs>
### Key Points to Note
- **Positive Integer Requirement:** Ensure that the number of iterations (`n_iterations`) is a positive integer. The code will raise a `ValueError` if this condition is not met.
- **Filename Requirement:** Ensure that the filename ends with `.pkl`. The code will raise a `ValueError` if this condition is not met.
- **Error Handling:** The code handles subprocess errors and unexpected exceptions, providing error messages to the user.
- Trained guidance is applied at prompt time; it does not modify your Python/YAML agent configuration.
- Agents automatically load trained suggestions from a file named `trained_agents_data.pkl` located in the current working directory. If you trained to a different filename, either rename it to `trained_agents_data.pkl` before running, or adjust the loader in code.
- You can change the output filename when calling `crewai train` with `-f/--filename`. Absolute paths are supported if you want to save outside the CWD.
It is important to note that the training process may take some time, depending on the complexity of your agents and will also require your feedback on each iteration.
Once the training is complete, your agents will be equipped with enhanced capabilities and knowledge, ready to tackle complex tasks and provide more consistent and valuable insights.
Remember to regularly update and retrain your agents to ensure they stay up-to-date with the latest information and advancements in the field.

View File

@@ -35,6 +35,22 @@ crewai tool install <tool-name>
This installs the tool and adds it to `pyproject.toml`.
You can use the tool by importing it and adding it to your agents:
```python
from your_tool.tool import YourTool
custom_tool = YourTool()
researcher = Agent(
role='Market Research Analyst',
goal='Provide up-to-date market analysis of the AI industry',
backstory='An expert analyst with a keen eye for market trends.',
tools=[custom_tool],
verbose=True
)
```
## Creating and Publishing Tools
To create a new tool project:

View File

@@ -0,0 +1,171 @@
---
title: "Automation Triggers"
description: "Automatically execute your CrewAI workflows when specific events occur in connected integrations"
icon: "bolt"
---
Automation triggers enable you to automatically run your CrewAI deployments when specific events occur in your connected integrations, creating powerful event-driven workflows that respond to real-time changes in your business systems.
## Overview
With automation triggers, you can:
- **Respond to real-time events** - Automatically execute workflows when specific conditions are met
- **Integrate with external systems** - Connect with platforms like Gmail, Outlook, OneDrive, JIRA, Slack, Stripe and more
- **Scale your automation** - Handle high-volume events without manual intervention
- **Maintain context** - Access trigger data within your crews and flows
## Managing Automation Triggers
### Viewing Available Triggers
To access and manage your automation triggers:
1. Navigate to your deployment in the CrewAI dashboard
2. Click on the **Triggers** tab to view all available trigger integrations
<Frame>
<img src="/images/enterprise/list-available-triggers.png" alt="List of available automation triggers" />
</Frame>
This view shows all the trigger integrations available for your deployment, along with their current connection status.
### Enabling and Disabling Triggers
Each trigger can be easily enabled or disabled using the toggle switch:
<Frame>
<img src="/images/enterprise/trigger-selected.png" alt="Enable or disable triggers with toggle" />
</Frame>
- **Enabled (blue toggle)**: The trigger is active and will automatically execute your deployment when the specified events occur
- **Disabled (gray toggle)**: The trigger is inactive and will not respond to events
Simply click the toggle to change the trigger state. Changes take effect immediately.
### Monitoring Trigger Executions
Track the performance and history of your triggered executions:
<Frame>
<img src="/images/enterprise/list-executions.png" alt="List of executions triggered by automation" />
</Frame>
## Building Automation
Before building your automation, it's helpful to understand the structure of trigger payloads that your crews and flows will receive.
### Payload Samples Repository
We maintain a comprehensive repository with sample payloads from various trigger sources to help you build and test your automations:
**🔗 [CrewAI Enterprise Trigger Payload Samples](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
This repository contains:
- **Real payload examples** from different trigger sources (Gmail, Google Drive, etc.)
- **Payload structure documentation** showing the format and available fields
### Triggers with Crew
Your existing crew definitions work seamlessly with triggers, you just need to have a task to parse the received payload:
```python
@CrewBase
class MyAutomatedCrew:
@agent
def researcher(self) -> Agent:
return Agent(
config=self.agents_config['researcher'],
)
@task
def parse_trigger_payload(self) -> Task:
return Task(
config=self.tasks_config['parse_trigger_payload'],
agent=self.researcher(),
)
@task
def analyze_trigger_content(self) -> Task:
return Task(
config=self.tasks_config['analyze_trigger_data'],
agent=self.researcher(),
)
```
The crew will automatically receive and can access the trigger payload through the standard CrewAI context mechanisms.
### Integration with Flows
For flows, you have more control over how trigger data is handled:
#### Accessing Trigger Payload
All `@start()` methods in your flows will accept an additional parameter called `crewai_trigger_payload`:
```python
from crewai.flow import Flow, start, listen
class MyAutomatedFlow(Flow):
@start()
def handle_trigger(self, crewai_trigger_payload: dict = None):
"""
This start method can receive trigger data
"""
if crewai_trigger_payload:
# Process the trigger data
trigger_id = crewai_trigger_payload.get('id')
event_data = crewai_trigger_payload.get('payload', {})
# Store in flow state for use by other methods
self.state.trigger_id = trigger_id
self.state.trigger_type = event_data
return event_data
# Handle manual execution
return None
@listen(handle_trigger)
def process_data(self, trigger_data):
"""
Process the data from the trigger
"""
# ... process the trigger
```
#### Triggering Crews from Flows
When kicking off a crew within a flow that was triggered, pass the trigger payload as it:
```python
@start()
def delegate_to_crew(self, crewai_trigger_payload: dict = None):
"""
Delegate processing to a specialized crew
"""
crew = MySpecializedCrew()
# Pass the trigger payload to the crew
result = crew.crew().kickoff(
inputs={
'a_custom_parameter': "custom_value",
'crewai_trigger_payload': crewai_trigger_payload
},
)
return result
```
## Troubleshooting
**Trigger not firing:**
- Verify the trigger is enabled
- Check integration connection status
**Execution failures:**
- Check the execution logs for error details
- If you are developing, make sure the inputs include the `crewai_trigger_payload` parameter with the correct payload
Automation triggers transform your CrewAI deployments into responsive, event-driven systems that can seamlessly integrate with your existing business processes and tools.

View File

@@ -117,4 +117,19 @@ agent = Agent(
)
```
## **Max Usage Count**
You can set a maximum usage count for a tool to prevent it from being used more than a certain number of times.
By default, the max usage count is unlimited.
```python
from crewai_tools import FileReadTool
tool = FileReadTool(max_usage_count=5, ...)
```
Ready to explore? Pick a category above to discover tools that fit your use case!

Binary file not shown.

After

Width:  |  Height:  |  Size: 142 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 330 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 133 KiB

View File

@@ -0,0 +1,7 @@
---
title: "GET /inputs"
description: "크루가 필요로 하는 입력 확인"
openapi: "/enterprise-api.ko.yaml GET /inputs"
---

View File

@@ -0,0 +1,7 @@
---
title: "POST /kickoff"
description: "크루 실행 시작"
openapi: "/enterprise-api.ko.yaml POST /kickoff"
---

View File

@@ -0,0 +1,7 @@
---
title: "GET /status/{kickoff_id}"
description: "실행 상태 조회"
openapi: "/enterprise-api.ko.yaml GET /status/{kickoff_id}"
---

View File

@@ -59,6 +59,7 @@ crew = Crew(
| **Pydantic 출력** _(선택 사항)_ | `output_pydantic` | `Optional[Type[BaseModel]]` | 태스크 출력용 Pydantic 모델입니다. |
| **콜백** _(선택 사항)_ | `callback` | `Optional[Any]` | 태스크 완료 후 실행할 함수/객체입니다. |
| **가드레일** _(선택 사항)_ | `guardrail` | `Optional[Callable]` | 다음 태스크로 진행하기 전에 태스크 출력을 검증하는 함수입니다. |
| **가드레일 최대 재시도** _(선택 사항)_ | `guardrail_max_retries` | `Optional[int]` | 가드레일 검증 실패 시 최대 재시도 횟수입니다. 기본값은 3입니다. |
## 작업 생성하기
@@ -448,7 +449,7 @@ task = Task(
expected_output="A valid JSON object",
agent=analyst,
guardrail=validate_json_output,
max_retries=3 # Limit retry attempts
guardrail_max_retries=3 # 재시도 횟수 제한
)
```
@@ -899,4 +900,4 @@ except RuntimeError as e:
작업(task)은 CrewAI 에이전트의 행동을 이끄는 원동력입니다.
작업과 그 결과를 적절하게 정의함으로써, 에이전트가 독립적으로 또는 협업 단위로 효과적으로 작동할 수 있는 기반을 마련할 수 있습니다.
작업에 적합한 도구를 장착하고, 실행 과정을 이해하며, 견고한 검증 절차를 따르는 것은 CrewAI의 잠재력을 극대화하는 데 필수적입니다.
이를 통해 에이전트가 할당된 작업에 효과적으로 준비되고, 작업이 의도대로 수행될 수 있습니다.
이를 통해 에이전트가 할당된 작업에 효과적으로 준비되고, 작업이 의도대로 수행될 수 있습니다.

View File

@@ -0,0 +1,171 @@
---
title: "자동화 트리거"
description: "연결된 통합에서 특정 이벤트가 발생할 때 CrewAI 워크플로우를 자동으로 실행합니다"
icon: "bolt"
---
자동화 트리거를 사용하면 연결된 통합에서 특정 이벤트가 발생할 때 CrewAI 배포를 자동으로 실행할 수 있어, 비즈니스 시스템의 실시간 변화에 반응하는 강력한 이벤트 기반 워크플로우를 만들 수 있습니다.
## 개요
자동화 트리거를 사용하면 다음을 수행할 수 있습니다:
- **실시간 이벤트에 응답** - 특정 조건이 충족될 때 워크플로우를 자동으로 실행
- **외부 시스템과 통합** - Gmail, Outlook, OneDrive, JIRA, Slack, Stripe 등의 플랫폼과 연결
- **자동화 확장** - 수동 개입 없이 대용량 이벤트 처리
- **컨텍스트 유지** - crew와 flow 내에서 트리거 데이터에 액세스
## 자동화 트리거 관리
### 사용 가능한 트리거 보기
자동화 트리거에 액세스하고 관리하려면:
1. CrewAI 대시보드에서 배포로 이동
2. **트리거** 탭을 클릭하여 사용 가능한 모든 트리거 통합 보기
<Frame>
<img src="/images/enterprise/list-available-triggers.png" alt="사용 가능한 자동화 트리거 목록" />
</Frame>
이 보기는 배포에 사용 가능한 모든 트리거 통합과 현재 연결 상태를 보여줍니다.
### 트리거 활성화 및 비활성화
각 트리거는 토글 스위치를 사용하여 쉽게 활성화하거나 비활성화할 수 있습니다:
<Frame>
<img src="/images/enterprise/trigger-selected.png" alt="토글로 트리거 활성화 또는 비활성화" />
</Frame>
- **활성화됨 (파란색 토글)**: 트리거가 활성 상태이며 지정된 이벤트가 발생할 때 배포를 자동으로 실행합니다
- **비활성화됨 (회색 토글)**: 트리거가 비활성 상태이며 이벤트에 응답하지 않습니다
토글을 클릭하기만 하면 트리거 상태를 변경할 수 있습니다. 변경 사항은 즉시 적용됩니다.
### 트리거 실행 모니터링
트리거된 실행의 성능과 기록을 추적합니다:
<Frame>
<img src="/images/enterprise/list-executions.png" alt="자동화에 의해 트리거된 실행 목록" />
</Frame>
## 자동화 구축
자동화를 구축하기 전에 crew와 flow가 받을 트리거 페이로드의 구조를 이해하는 것이 도움이 됩니다.
### 페이로드 샘플 저장소
자동화를 구축하고 테스트하는 데 도움이 되도록 다양한 트리거 소스의 샘플 페이로드가 포함된 포괄적인 저장소를 유지 관리하고 있습니다:
**🔗 [CrewAI Enterprise 트리거 페이로드 샘플](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
이 저장소에는 다음이 포함되어 있습니다:
- **실제 페이로드 예제** - 다양한 트리거 소스(Gmail, Google Drive 등)에서 가져온 예제
- **페이로드 구조 문서** - 형식과 사용 가능한 필드를 보여주는 문서
### Crew와 트리거
기존 crew 정의는 트리거와 완벽하게 작동하며, 받은 페이로드를 분석하는 작업만 있으면 됩니다:
```python
@CrewBase
class MyAutomatedCrew:
@agent
def researcher(self) -> Agent:
return Agent(
config=self.agents_config['researcher'],
)
@task
def parse_trigger_payload(self) -> Task:
return Task(
config=self.tasks_config['parse_trigger_payload'],
agent=self.researcher(),
)
@task
def analyze_trigger_content(self) -> Task:
return Task(
config=self.tasks_config['analyze_trigger_data'],
agent=self.researcher(),
)
```
crew는 자동으로 트리거 페이로드를 받고 표준 CrewAI 컨텍스트 메커니즘을 통해 액세스할 수 있습니다.
### Flow와의 통합
flow의 경우 트리거 데이터 처리 방법을 더 세밀하게 제어할 수 있습니다:
#### 트리거 페이로드 액세스
flow의 모든 `@start()` 메서드는 `crewai_trigger_payload`라는 추가 매개변수를 허용합니다:
```python
from crewai.flow import Flow, start, listen
class MyAutomatedFlow(Flow):
@start()
def handle_trigger(self, crewai_trigger_payload: dict = None):
"""
이 start 메서드는 트리거 데이터를 받을 수 있습니다
"""
if crewai_trigger_payload:
# 트리거 데이터 처리
trigger_id = crewai_trigger_payload.get('id')
event_data = crewai_trigger_payload.get('payload', {})
# 다른 메서드에서 사용할 수 있도록 flow 상태에 저장
self.state.trigger_id = trigger_id
self.state.trigger_type = event_data
return event_data
# 수동 실행 처리
return None
@listen(handle_trigger)
def process_data(self, trigger_data):
"""
트리거 데이터 처리
"""
# ... 트리거 처리
```
#### Flow에서 Crew 트리거하기
트리거된 flow 내에서 crew를 시작할 때 트리거 페이로드를 전달합니다:
```python
@start()
def delegate_to_crew(self, crewai_trigger_payload: dict = None):
"""
전문 crew에 처리 위임
"""
crew = MySpecializedCrew()
# crew에 트리거 페이로드 전달
result = crew.crew().kickoff(
inputs={
'a_custom_parameter': "custom_value",
'crewai_trigger_payload': crewai_trigger_payload
},
)
return result
```
## 문제 해결
**트리거가 작동하지 않는 경우:**
- 트리거가 활성화되어 있는지 확인
- 통합 연결 상태 확인
**실행 실패:**
- 오류 세부 정보는 실행 로그 확인
- 개발 중인 경우 입력에 올바른 페이로드가 포함된 `crewai_trigger_payload` 매개변수가 포함되어 있는지 확인
자동화 트리거는 CrewAI 배포를 기존 비즈니스 프로세스 및 도구와 완벽하게 통합할 수 있는 반응형 이벤트 기반 시스템으로 변환합니다.

View File

@@ -1,65 +1,65 @@
---
title: 소개
description: 함께 협력하여 복잡한 작업을 해결하는 AI 에이전트 팀 구축
description: 함께 협력하여 복잡한 작업을 해결하는 AI agent 팀 구축
icon: handshake
---
# CrewAI란 무엇인가?
**CrewAI는 완전히 독립적으로, LangChain이나 기타 agent 프레임워크에 의존하지 않고 처음부터 스크래치로 개발된 가볍고 매우 빠른 Python 프레임워크입니다.**
**CrewAI는 LangChain이나 기타 agent 프레임워크에 의존하지 않고, 완전히 독립적으로 처음부터 스크래치로 개발된 가볍고 매우 빠른 Python 프레임워크입니다.**
CrewAI는 고수준의 간편함과 정밀한 저수준 제어를 모두 제공하여, 어떤 시나리오에도 맞춤화된 자율 AI agent를 만드는 데 이상적입니다:
- **[CrewAI Crews](/ko/guides/crews/first-crew)**: 자율성과 협업 지능을 극대화하여, 각 agent가 특정 역할, 도구, 목표를 가진 AI 팀을 만들 수 있습니다.
- **[CrewAI Flows](/ko/guides/flows/first-flow)**: 세밀한 이벤트 기반 제어와 단일 LLM 호출을 통한 정확한 작업 오케스트레이션을 가능하게 하며 Crews 네이티브로 지원합니다.
- **[CrewAI Flows](/ko/guides/flows/first-flow)**: 이벤트 기반의 세밀한 제어와 단일 LLM 호출을 통한 정확한 작업 orchestration을 지원하며, Crews 네이티브로 통합됩니다.
10만 명이 넘는 개발자가 커뮤니티 과정을 통해 인증을 받았으며, CrewAI는 기업용 AI 자동화의 표준으로 빠르게 자리잡고 있습니다.
## 크루 작동 방식
## Crew의 작동 방식
<Note>
회사가 비즈니스 목표를 달성하기 위해 여러 부서(영업, 엔지니어링, 마케팅 등)가 리더십 아래에서 함께 일하는 것처럼, CrewAI는 복잡한 작업을 달성하기 위해 전문화된 역할의 AI 에이전트들이 협력하는 조직을 만들 수 있도록 도와줍니다.
회사가 비즈니스 목표를 달성하기 위해 여러 부서(영업, 엔지니어링, 마케팅 등)가 리더십 아래에서 함께 일하는 것처럼, CrewAI는 복잡한 작업을 달성하기 위해 전문화된 역할의 AI agent들이 협력하는 조직을 만들 수 있도록 도와줍니다.
</Note>
<Frame caption="CrewAI 프레임워크 개요">
<Frame caption="CrewAI Framework Overview">
<img src="/images/crews.png" alt="CrewAI Framework Overview" />
</Frame>
| 구성 요소 | 설명 | 주요 특징 |
|:--------------|:---------------------:|:----------|
| **크루** | 최상위 조직 | • AI 에이전트 팀 관리<br/>• 워크플로우 감독<br/>• 협업 보장<br/>• 결과 전달 |
| **AI 에이전트** | 전문 팀원 | • 특정 역할 보유(연구원, 작가 등)<br/>• 지정된 도구 사용<br/>• 작업 위임 가능<br/>• 자율적 의사결정 가능 |
| **프로세스** | 워크플로우 관리 시스템 | • 협업 패턴 정의<br/>• 작업 할당 제어<br/>• 상호작용 관리<br/>• 효율적 실행 보장 |
| **작업** | 개별 할당 | • 명확한 목표 보유<br/>• 특정 도구 사용<br/>• 더 큰 프로세스에 기여<br/>• 실행 가능한 결과 도출 |
| 구성 요소 | 설명 | 주요 특징 |
|:----------|:----:|:----------|
| **Crew** | 최상위 조직 | • AI agent 팀 관리<br/>• workflow 감독<br/>• 협업 보장<br/>• 결과 전달 |
| **AI agents** | 전문 팀원 | • 특정 역할 보유(Researcher, Writer 등)<br/>• 지정된 도구 사용<br/>• 작업 위임 가능<br/>• 자율적 의사결정 가능 |
| **Process** | workflow 관리 시스템 | • 협업 패턴 정의<br/>• 작업 할당 제어<br/>• 상호작용 관리<br/>• 효율적 실행 보장 |
| **Task** | 개별 할당 | • 명확한 목표 보유<br/>• 특정 도구 사용<br/>• 더 큰 프로세스에 기여<br/>• 실행 가능한 결과 도출 |
### 어떻게 모두 함께 작동하는가
### 전체 구조의 동작 방식
1. **Crew**가 전체 운영을 조직합니다
2. **AI Agents**가 자신들의 전문 작업을 수행합니다
2. **AI agents**가 자신들의 전문 작업을 수행합니다
3. **Process**가 원활한 협업을 보장합니다
4. **Tasks**가 완료되어 목표를 달성합니다
## 주요 기능
<CardGroup cols={2}>
<Card title="역할 기반 에이전트" icon="users">
연구원, 분석가, 작가 등 다양한 역할, 전문성, 목표를 가진 전문 에이전트를 생성할 수 있습니다
<Card title="역할 기반 agent" icon="users">
Researcher, Analyst, Writer 등 다양한 역할 전문성, 목표를 가진 agent를 생성할 수 있습니다
</Card>
<Card title="유연한 도구" icon="screwdriver-wrench">
에이전트에게 외부 서비스 및 데이터 소스와 상호작용할 수 있는 맞춤형 도구와 API를 제공합니다
agent에게 외부 서비스 및 데이터 소스와 상호작용할 수 있는 맞춤형 도구와 API를 제공합니다
</Card>
<Card title="지능형 협업" icon="people-arrows">
에이전트가 함께 작업하며, 인사이트를 공유하고 작업을 조율하여 복잡한 목표를 달성합니다
agent들이 함께 작업하며, 인사이트를 공유하고 작업을 조율하여 복잡한 목표를 달성합니다
</Card>
<Card title="작업 관리" icon="list-check">
순차적 또는 병렬 워크플로우를 정의할 수 있으며, 에이전트가 작업 의존성을 자동으로 처리합니다
순차적 또는 병렬 workflow를 정의할 수 있으며, agent가 작업 의존성을 자동으로 처리합니다
</Card>
</CardGroup>
## 플로우의 작동 원리
## Flow의 작동 원리
<Note>
crew 자율 협업에 탁월한 반면, 플로우는 구조화된 자동화를 제공하여 워크플로우 실행에 대한 세밀한 제어를 제공합니다. 플로우는 조건부 로직, 반복문, 동적 상태 관리를 정확하게 처리하면서 작업이 신뢰성 있게, 안전하게, 효율적으로 실행되도록 보장합니다. 플로우crew와 원활하게 통합되어 높은 자율성과 엄격한 제어의 균형을 이룰 수 있게 해줍니다.
Crew 자율 협업에 탁월하다면, Flow는 구조화된 자동화를 제공하여 workflow 실행에 대한 세밀한 제어를 제공합니다. Flow는 조건부 로직, 반복문, 동적 상태 관리를 정확하게 처리하면서 작업이 신뢰성 있게, 안전하게, 효율적으로 실행되도록 보장합니다. FlowCrew와 원활하게 통합되어 높은 자율성과 엄격한 제어의 균형을 이룰 수 있게 해줍니다.
</Note>
<Frame caption="CrewAI Framework Overview">
@@ -68,41 +68,41 @@ CrewAI는 고수준의 간편함과 정밀한 저수준 제어를 모두 제공
| 구성 요소 | 설명 | 주요 기능 |
|:----------|:-----------:|:------------|
| **Flow** | 구조화된 워크플로우 오케스트레이션 | • 실행 경로 관리<br/>• 상태 전환 처리<br/>• 작업 순서 제어<br/>• 신뢰성 있는 실행 보장 |
| **Events** | 워크플로우 액션 트리거 | • 특정 프로세스 시작<br/>• 동적 응답 가능<br/>• 조건부 분기 지원<br/>• 실시간 적응 허용 |
| **States** | 워크플로우 실행 컨텍스트 | • 실행 데이터 유지<br/>• 데이터 영속성 지원<br/>• 재개 가능성 보장<br/>• 실행 무결성 확보 |
| **Crew Support** | 워크플로우 자동화 강화 | • 필요할 때 agency 삽입<br/>• 구조화된 워크플로우 보완<br/>• 자동화와 인텔리전스의 균형<br/>• 적응적 의사결정 지원 |
| **Flow** | 구조화된 workflow orchestration | • 실행 경로 관리<br/>• 상태 전환 처리<br/>• 작업 순서 제어<br/>• 신뢰성 있는 실행 보장 |
| **Events** | workflow 액션 트리거 | • 특정 프로세스 시작<br/>• 동적 응답 가능<br/>• 조건부 분기 지원<br/>• 실시간 적응 허용 |
| **States** | workflow 실행 컨텍스트 | • 실행 데이터 유지<br/>• 데이터 영속성 지원<br/>• 재개 가능성 보장<br/>• 실행 무결성 확보 |
| **Crew Support** | workflow 자동화 강화 | • 필요할 때 agency 삽입<br/>• 구조화된 workflow 보완<br/>• 자동화와 인텔리전스의 균형<br/>• 적응적 의사결정 지원 |
### 주요 기능
<CardGroup cols={2}>
<Card title="이벤트 기반 오케스트레이션" icon="bolt">
이벤트에 동적으로 반응하여 정밀한 실행 경로 정의
<Card title="이벤트 기반 orchestration" icon="bolt">
이벤트에 동적으로 반응하여 정밀한 실행 경로 정의합니다
</Card>
<Card title="세밀한 제어" icon="sliders">
워크플로우 상태와 조건부 실행을 안전하고 효율적으로 관리
workflow 상태와 조건부 실행을 안전하고 효율적으로 관리합니다
</Card>
<Card title="네이티브 Crew 통합" icon="puzzle-piece">
Crews와 손쉽게 결합하여 자율성과 지능 강화
Crews와 손쉽게 결합하여 자율성과 지능 강화합니다
</Card>
<Card title="결정론적 실행" icon="route">
명시적 제어 흐름과 오류 처리로 예측 가능한 결과 보장
명시적 제어 흐름과 오류 처리로 예측 가능한 결과 보장합니다
</Card>
</CardGroup>
## 크루(Crews)와 플로우(Flows)를 언제 사용할까
## CrewFlow를 언제 사용할까
<Note>
[크루](/ko/guides/crews/first-crew)와 [플로우](/ko/guides/flows/first-flow)를 언제 사용할지 이해하는 것은 CrewAI의 잠재력을 애플리케이션에서 극대화하는 데 핵심적입니다.
[Crew](/ko/guides/crews/first-crew)와 [Flow](/ko/guides/flows/first-flow)를 언제 사용할지 이해하는 것은 CrewAI의 잠재력을 애플리케이션에서 극대화하는 데 핵심적입니다.
</Note>
| 사용 사례 | 권장 접근 방식 | 이유 |
|:---------|:---------------------|:-----|
| **개방형 연구** | [크루](/ko/guides/crews/first-crew) | 과제가 창의적 사고, 탐색, 적응이 필요할 때 |
| **콘텐츠 생성** | [크루](/ko/guides/crews/first-crew) | 기사, 보고서, 마케팅 자료 등 협업형 생성 |
| **의사결정 워크플로우** | [플로우](/ko/guides/flows/first-flow) | 예측 가능하고 감사 가능한 의사결정 경로 및 정밀 제어가 필요할 때 |
| **API 오케스트레이션** | [플로우](/ko/guides/flows/first-flow) | 특정 순서로 여러 외부 서비스에 신뢰성 있게 통합할 때 |
| **하이브리드 애플리케이션** | 혼합 접근 방식 | [플로우](/ko/guides/flows/first-flow)로 전체 프로세스를 오케스트레이션하고, [크루](/ko/guides/crews/first-crew)로 복잡한 하위 작업을 처리 |
| **개방형 연구** | [Crew](/ko/guides/crews/first-crew) | 창의적 사고, 탐색, 적응이 필요한 작업에 적합 |
| **콘텐츠 생성** | [Crew](/ko/guides/crews/first-crew) | 기사, 보고서, 마케팅 자료 등 협업형 생성에 적합 |
| **의사결정 workflow** | [Flow](/ko/guides/flows/first-flow) | 예측 가능하고 감사 가능한 의사결정 경로 및 정밀 제어가 필요할 때 |
| **API orchestration** | [Flow](/ko/guides/flows/first-flow) | 특정 순서로 여러 외부 서비스에 신뢰성 있게 통합할 때 |
| **하이브리드 애플리케이션** | 혼합 접근 방식 | [Flow](/ko/guides/flows/first-flow)로 전체 프로세스를 orchestration하고, [Crew](/ko/guides/crews/first-crew)로 복잡한 하위 작업을 처리 |
### 의사결정 프레임워크
@@ -112,8 +112,8 @@ CrewAI는 고수준의 간편함과 정밀한 저수준 제어를 모두 제공
## CrewAI를 선택해야 하는 이유?
- 🧠 **자율적 운영**: 에이전트가 자신의 역할과 사용 가능한 도구를 바탕으로 지능적인 결정을 내립니다
- 📝 **자연스러운 상호작용**: 에이전트가 인간 팀원처럼 소통하고 협업합니다
- 🧠 **자율적 운영**: agent가 자신의 역할과 사용 가능한 도구를 바탕으로 지능적인 결정을 내립니다
- 📝 **자연스러운 상호작용**: agent가 인간 팀원처럼 소통하고 협업합니다
- 🛠️ **확장 가능한 설계**: 새로운 도구, 역할, 기능을 쉽게 추가할 수 있습니다
- 🚀 **프로덕션 준비 완료**: 실제 환경에서의 신뢰성과 확장성을 고려하여 구축되었습니다
- 🔒 **보안 중심**: 엔터프라이즈 보안 요구 사항을 고려하여 설계되었습니다
@@ -134,7 +134,7 @@ CrewAI는 고수준의 간편함과 정밀한 저수준 제어를 모두 제공
icon="diagram-project"
href="/ko/guides/flows/first-flow"
>
실행을 정밀하게 제어할 수 있는 구조화된, 이벤트 기반 워크플로우를 만드는 방법을 배워보세요.
실행을 정밀하게 제어할 수 있는 구조화된, 이벤트 기반 workflow를 만드는 방법을 배워보세요.
</Card>
</CardGroup>
@@ -151,7 +151,7 @@ CrewAI는 고수준의 간편함과 정밀한 저수준 제어를 모두 제공
icon="bolt"
href="ko/quickstart"
>
빠른 시작 가이드를 따라 첫 번째 CrewAI 에이전트를 만들고 직접 경험해 보세요.
빠른 시작 가이드를 따라 첫 번째 CrewAI agent를 만들고 직접 경험해 보세요.
</Card>
<Card
title="커뮤니티 가입하기"

View File

@@ -0,0 +1,7 @@
---
title: "GET /inputs"
description: "Obter entradas necessárias para sua crew"
openapi: "/enterprise-api.pt-BR.yaml GET /inputs"
---

View File

@@ -0,0 +1,7 @@
---
title: "POST /kickoff"
description: "Iniciar a execução da crew"
openapi: "/enterprise-api.pt-BR.yaml POST /kickoff"
---

View File

@@ -0,0 +1,7 @@
---
title: "GET /status/{kickoff_id}"
description: "Obter o status da execução"
openapi: "/enterprise-api.pt-BR.yaml GET /status/{kickoff_id}"
---

View File

@@ -59,6 +59,7 @@ crew = Crew(
| **Output Pydantic** _(opcional)_ | `output_pydantic` | `Optional[Type[BaseModel]]` | Um modelo Pydantic para a saída da tarefa. |
| **Callback** _(opcional)_ | `callback` | `Optional[Any]` | Função/objeto a ser executado após a conclusão da tarefa. |
| **Guardrail** _(opcional)_ | `guardrail` | `Optional[Callable]` | Função para validar a saída da tarefa antes de prosseguir para a próxima tarefa. |
| **Max Tentativas Guardrail** _(opcional)_ | `guardrail_max_retries` | `Optional[int]` | Número máximo de tentativas quando a validação do guardrail falha. Padrão é 3. |
## Criando Tarefas
@@ -450,7 +451,7 @@ task = Task(
expected_output="Um objeto JSON válido",
agent=analyst,
guardrail=validate_json_output,
max_retries=3 # Limite de tentativas
guardrail_max_retries=3 # Limite de tentativas
)
```
@@ -935,7 +936,7 @@ task = Task(
description="Gerar dados",
expected_output="Dados válidos",
guardrail=validate_data,
max_retries=5 # Sobrescreve o limite padrão de tentativas
guardrail_max_retries=5 # Sobrescreve o limite padrão de tentativas
)
```

View File

@@ -0,0 +1,171 @@
---
title: "Triggers de Automação"
description: "Execute automaticamente seus workflows CrewAI quando eventos específicos ocorrem em integrações conectadas"
icon: "bolt"
---
Os triggers de automação permitem executar automaticamente suas implantações CrewAI quando eventos específicos ocorrem em suas integrações conectadas, criando workflows poderosos orientados por eventos que respondem a mudanças em tempo real em seus sistemas de negócio.
## Visão Geral
Com triggers de automação, você pode:
- **Responder a eventos em tempo real** - Execute workflows automaticamente quando condições específicas forem atendidas
- **Integrar com sistemas externos** - Conecte com plataformas como Gmail, Outlook, OneDrive, JIRA, Slack, Stripe e muito mais
- **Escalar sua automação** - Lide com eventos de alto volume sem intervenção manual
- **Manter contexto** - Acesse dados do trigger dentro de suas crews e flows
## Gerenciando Triggers de Automação
### Visualizando Triggers Disponíveis
Para acessar e gerenciar seus triggers de automação:
1. Navegue até sua implantação no painel do CrewAI
2. Clique na aba **Triggers** para visualizar todas as integrações de trigger disponíveis
<Frame>
<img src="/images/enterprise/list-available-triggers.png" alt="Lista de triggers de automação disponíveis" />
</Frame>
Esta visualização mostra todas as integrações de trigger disponíveis para sua implantação, junto com seus status de conexão atuais.
### Habilitando e Desabilitando Triggers
Cada trigger pode ser facilmente habilitado ou desabilitado usando o botão de alternância:
<Frame>
<img src="/images/enterprise/trigger-selected.png" alt="Habilitar ou desabilitar triggers com alternância" />
</Frame>
- **Habilitado (alternância azul)**: O trigger está ativo e executará automaticamente sua implantação quando os eventos especificados ocorrerem
- **Desabilitado (alternância cinza)**: O trigger está inativo e não responderá a eventos
Simplesmente clique na alternância para mudar o estado do trigger. As alterações entram em vigor imediatamente.
### Monitorando Execuções de Trigger
Acompanhe o desempenho e histórico de suas execuções acionadas:
<Frame>
<img src="/images/enterprise/list-executions.png" alt="Lista de execuções acionadas por automação" />
</Frame>
## Construindo Automação
Antes de construir sua automação, é útil entender a estrutura dos payloads de trigger que suas crews e flows receberão.
### Repositório de Amostras de Payload
Mantemos um repositório abrangente com amostras de payload de várias fontes de trigger para ajudá-lo a construir e testar suas automações:
**🔗 [Amostras de Payload de Trigger CrewAI Enterprise](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
Este repositório contém:
- **Exemplos reais de payload** de diferentes fontes de trigger (Gmail, Google Drive, etc.)
- **Documentação da estrutura de payload** mostrando o formato e campos disponíveis
### Triggers com Crew
Suas definições de crew existentes funcionam perfeitamente com triggers, você só precisa ter uma tarefa para analisar o payload recebido:
```python
@CrewBase
class MinhaCrewAutomatizada:
@agent
def pesquisador(self) -> Agent:
return Agent(
config=self.agents_config['pesquisador'],
)
@task
def analisar_payload_trigger(self) -> Task:
return Task(
config=self.tasks_config['analisar_payload_trigger'],
agent=self.pesquisador(),
)
@task
def analisar_conteudo_trigger(self) -> Task:
return Task(
config=self.tasks_config['analisar_dados_trigger'],
agent=self.pesquisador(),
)
```
A crew receberá automaticamente e pode acessar o payload do trigger através dos mecanismos de contexto padrão do CrewAI.
### Integração com Flows
Para flows, você tem mais controle sobre como os dados do trigger são tratados:
#### Acessando Payload do Trigger
Todos os métodos `@start()` em seus flows aceitarão um parâmetro adicional chamado `crewai_trigger_payload`:
```python
from crewai.flow import Flow, start, listen
class MeuFlowAutomatizado(Flow):
@start()
def lidar_com_trigger(self, crewai_trigger_payload: dict = None):
"""
Este método start pode receber dados do trigger
"""
if crewai_trigger_payload:
# Processa os dados do trigger
trigger_id = crewai_trigger_payload.get('id')
dados_evento = crewai_trigger_payload.get('payload', {})
# Armazena no estado do flow para uso por outros métodos
self.state.trigger_id = trigger_id
self.state.trigger_type = dados_evento
return dados_evento
# Lida com execução manual
return None
@listen(lidar_com_trigger)
def processar_dados(self, dados_trigger):
"""
Processa os dados do trigger
"""
# ... processa o trigger
```
#### Acionando Crews a partir de Flows
Ao iniciar uma crew dentro de um flow que foi acionado, passe o payload do trigger conforme ele:
```python
@start()
def delegar_para_crew(self, crewai_trigger_payload: dict = None):
"""
Delega processamento para uma crew especializada
"""
crew = MinhaCrewEspecializada()
# Passa o payload do trigger para a crew
resultado = crew.crew().kickoff(
inputs={
'parametro_personalizado': "valor_personalizado",
'crewai_trigger_payload': crewai_trigger_payload
},
)
return resultado
```
## Solução de Problemas
**Trigger não está sendo disparado:**
- Verifique se o trigger está habilitado
- Verifique o status de conexão da integração
**Falhas de execução:**
- Verifique os logs de execução para detalhes do erro
- Se você está desenvolvendo, certifique-se de que as entradas incluem o parâmetro `crewai_trigger_payload` com o payload correto
Os triggers de automação transformam suas implantações CrewAI em sistemas responsivos orientados por eventos que podem se integrar perfeitamente com seus processos de negócio e ferramentas existentes.

View File

@@ -48,7 +48,7 @@ Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = ["crewai-tools~=0.62.0"]
tools = ["crewai-tools~=0.62.1"]
embeddings = [
"tiktoken~=0.8.0"
]
@@ -68,6 +68,9 @@ docling = [
aisuite = [
"aisuite>=0.1.10",
]
qdrant = [
"qdrant-client[fastembed]>=1.14.3",
]
[tool.uv]
dev-dependencies = [
@@ -98,6 +101,11 @@ exclude = ["cli/templates"]
[tool.bandit]
exclude_dirs = ["src/crewai/cli/templates"]
[tool.pytest.ini_options]
markers = [
"telemetry: mark test as a telemetry test (don't mock telemetry)",
]
# PyTorch index configuration, since torch 2.5.0 is not compatible with python 3.13
[[tool.uv.index]]
name = "pytorch-nightly"

View File

@@ -54,7 +54,7 @@ def _track_install_async():
_track_install_async()
__version__ = "0.159.0"
__version__ = "0.165.1"
__all__ = [
"Agent",
"Crew",

View File

@@ -1,7 +1,18 @@
import shutil
import subprocess
import time
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
@@ -162,7 +173,7 @@ class Agent(BaseAgent):
)
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
default=None,
description="Function or string description of a guardrail to validate agent output"
description="Function or string description of a guardrail to validate agent output",
)
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
@@ -276,7 +287,7 @@ class Agent(BaseAgent):
self._inject_date_to_task(task)
if self.tools_handler:
self.tools_handler.last_used_tool = {} # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
self.tools_handler.last_used_tool = None
task_prompt = task.prompt()
@@ -309,15 +320,20 @@ class Agent(BaseAgent):
event=MemoryRetrievalStartedEvent(
task_id=str(task.id) if task else None,
source_type="agent",
from_agent=self,
from_task=task,
),
)
start_time = time.time()
contextual_memory = ContextualMemory(
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._external_memory,
agent=self,
task=task,
)
memory = contextual_memory.build_context_for_task(task, context)
if memory.strip() != "":
@@ -330,13 +346,14 @@ class Agent(BaseAgent):
memory_content=memory,
retrieval_time_ms=(time.time() - start_time) * 1000,
source_type="agent",
from_agent=self,
from_task=task,
),
)
knowledge_config = (
self.knowledge_config.model_dump() if self.knowledge_config else {}
)
if self.knowledge or (self.crew and self.crew.knowledge):
crewai_event_bus.emit(
self,

View File

@@ -43,7 +43,6 @@ class CrewAgentExecutorMixin:
metadata={
"observation": self.task.description,
},
agent=self.agent.role,
)
except Exception as e:
print(f"Failed to add to short term memory: {e}")
@@ -65,7 +64,6 @@ class CrewAgentExecutorMixin:
"description": self.task.description,
"messages": self.messages,
},
agent=self.agent.role,
)
except Exception as e:
print(f"Failed to add to external memory: {e}")
@@ -158,7 +156,9 @@ class CrewAgentExecutorMixin:
self._printer.print(content=prompt, color="bold_yellow")
response = input()
if response.strip() != "":
self._printer.print(content="\nProcessing your feedback...", color="cyan")
self._printer.print(
content="\nProcessing your feedback...", color="cyan"
)
return response
finally:
event_listener.formatter.resume_live_updates()

View File

@@ -8,13 +8,13 @@ from .cache.cache_handler import CacheHandler
class ToolsHandler:
"""Callback handler for tool usage."""
last_used_tool: ToolCalling = {} # type: ignore # BUG?: Incompatible types in assignment (expression has type "Dict[...]", variable has type "ToolCalling")
last_used_tool: Optional[ToolCalling] = None
cache: Optional[CacheHandler]
def __init__(self, cache: Optional[CacheHandler] = None):
"""Initialize the callback handler."""
self.cache = cache
self.last_used_tool = {} # type: ignore # BUG?: same as above
self.last_used_tool = None
def on_tool_use(
self,

View File

@@ -7,7 +7,8 @@ from rich.console import Console
from pydantic import BaseModel, Field
from .utils import TokenManager, validate_jwt_token
from .utils import validate_jwt_token
from crewai.cli.shared.token_manager import TokenManager
from urllib.parse import quote
from crewai.cli.plus_api import PlusAPI
from crewai.cli.config import Settings
@@ -21,10 +22,19 @@ console = Console()
class Oauth2Settings(BaseModel):
provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).")
client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.")
domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.")
audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None)
provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
)
client_id: str = Field(
description="OAuth2 client ID issued by the provider, used during authentication requests."
)
domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
)
audience: Optional[str] = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=None,
)
@classmethod
def from_settings(cls):
@@ -44,11 +54,15 @@ class ProviderFactory:
settings = settings or Oauth2Settings.from_settings()
import importlib
module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
module = importlib.import_module(
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
)
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
return provider(settings)
class AuthenticationCommand:
def __init__(self):
self.token_manager = TokenManager()
@@ -65,7 +79,7 @@ class AuthenticationCommand:
provider="auth0",
client_id=AUTH0_CLIENT_ID,
domain=AUTH0_DOMAIN,
audience=AUTH0_AUDIENCE
audience=AUTH0_AUDIENCE,
)
self.oauth2_provider = ProviderFactory.from_settings(settings)
# End of temporary code.
@@ -75,9 +89,7 @@ class AuthenticationCommand:
return self._poll_for_token(device_code_data)
def _get_device_code(
self
) -> Dict[str, Any]:
def _get_device_code(self) -> Dict[str, Any]:
"""Get the device code to authenticate the user."""
device_code_payload = {
@@ -86,7 +98,9 @@ class AuthenticationCommand:
"audience": self.oauth2_provider.get_audience(),
}
response = requests.post(
url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20
url=self.oauth2_provider.get_authorize_url(),
data=device_code_payload,
timeout=20,
)
response.raise_for_status()
return response.json()
@@ -97,9 +111,7 @@ class AuthenticationCommand:
console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(device_code_data["verification_uri_complete"])
def _poll_for_token(
self, device_code_data: Dict[str, Any]
) -> None:
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
"""Polls the server for the token until it is received, or max attempts are reached."""
token_payload = {
@@ -112,7 +124,9 @@ class AuthenticationCommand:
attempts = 0
while True and attempts < 10:
response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30)
response = requests.post(
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
)
token_data = response.json()
if response.status_code == 200:

View File

@@ -1,4 +1,4 @@
from .utils import TokenManager
from crewai.cli.shared.token_manager import TokenManager
class AuthError(Exception):

View File

@@ -1,12 +1,5 @@
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
import jwt
from jwt import PyJWKClient
from cryptography.fernet import Fernet
def validate_jwt_token(
@@ -67,118 +60,3 @@ def validate_jwt_token(
raise Exception(f"JWKS or key processing error: {str(e)}")
except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {str(e)}")
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""
Initialize the TokenManager class.
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key.
:return: The encryption key.
"""
key_filename = "secret.key"
key = self.read_secure_file(key_filename)
if key is not None:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""
Save the access token and its expiration time.
:param access_token: The access token to save.
:param expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> Optional[str]:
"""
Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None.
"""
encrypted_data = self.read_secure_file(self.file_path)
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return data["access_token"]
def get_secure_storage_path(self) -> Path:
"""
Get the secure storage path based on the operating system.
:return: The secure storage path.
"""
if sys.platform == "win32":
# Windows: Use %LOCALAPPDATA%
base_path = os.environ.get("LOCALAPPDATA")
elif sys.platform == "darwin":
# macOS: Use ~/Library/Application Support
base_path = os.path.expanduser("~/Library/Application Support")
else:
# Linux and other Unix-like: Use ~/.local/share
base_path = os.path.expanduser("~/.local/share")
app_name = "crewai/credentials"
storage_path = Path(base_path) / app_name
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None:
"""
Save the content to a secure file.
:param filename: The name of the file.
:param content: The content to save.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
with open(file_path, "wb") as f:
f.write(content)
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> Optional[bytes]:
"""
Read the content of a secure file.
:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
if not file_path.exists():
return None
with open(file_path, "rb") as f:
return f.read()

View File

@@ -11,6 +11,7 @@ from crewai.cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
)
from crewai.cli.shared.token_manager import TokenManager
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
@@ -53,6 +54,7 @@ HIDDEN_SETTINGS_KEYS = [
"tool_repository_password",
]
class Settings(BaseModel):
enterprise_base_url: Optional[str] = Field(
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
@@ -74,12 +76,12 @@ class Settings(BaseModel):
oauth2_provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
default=DEFAULT_CLI_SETTINGS["oauth2_provider"]
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
)
oauth2_audience: Optional[str] = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=DEFAULT_CLI_SETTINGS["oauth2_audience"]
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
)
oauth2_client_id: str = Field(
@@ -89,7 +91,7 @@ class Settings(BaseModel):
oauth2_domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
default=DEFAULT_CLI_SETTINGS["oauth2_domain"]
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
)
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
@@ -116,6 +118,7 @@ class Settings(BaseModel):
"""Reset all settings to default values"""
self._reset_user_settings()
self._reset_cli_settings()
self._clear_auth_tokens()
self.dump()
def dump(self) -> None:
@@ -139,3 +142,7 @@ class Settings(BaseModel):
"""Reset all CLI settings to default values"""
for key in CLI_SETTINGS_KEYS:
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
def _clear_auth_tokens(self) -> None:
"""Clear all authentication tokens"""
TokenManager().clear_tokens()

View File

@@ -117,9 +117,6 @@ class PlusAPI:
def get_organizations(self) -> requests.Response:
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
def send_trace_batch(self, payload) -> requests.Response:
return self._make_request("POST", self.TRACING_RESOURCE, json=payload)
def initialize_trace_batch(self, payload) -> requests.Response:
return self._make_request(
"POST", f"{self.TRACING_RESOURCE}/batches", json=payload
@@ -135,6 +132,7 @@ class PlusAPI:
"POST",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
json=payload,
timeout=30,
)
def send_ephemeral_trace_events(
@@ -144,6 +142,7 @@ class PlusAPI:
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
json=payload,
timeout=30,
)
def finalize_trace_batch(self, trace_batch_id: str, payload) -> requests.Response:

View File

View File

@@ -0,0 +1,139 @@
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
from cryptography.fernet import Fernet
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""
Initialize the TokenManager class.
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key.
:return: The encryption key.
"""
key_filename = "secret.key"
key = self.read_secure_file(key_filename)
if key is not None:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""
Save the access token and its expiration time.
:param access_token: The access token to save.
:param expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> Optional[str]:
"""
Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None.
"""
encrypted_data = self.read_secure_file(self.file_path)
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return data["access_token"]
def clear_tokens(self) -> None:
"""
Clear the tokens.
"""
self.delete_secure_file(self.file_path)
def get_secure_storage_path(self) -> Path:
"""
Get the secure storage path based on the operating system.
:return: The secure storage path.
"""
if sys.platform == "win32":
# Windows: Use %LOCALAPPDATA%
base_path = os.environ.get("LOCALAPPDATA")
elif sys.platform == "darwin":
# macOS: Use ~/Library/Application Support
base_path = os.path.expanduser("~/Library/Application Support")
else:
# Linux and other Unix-like: Use ~/.local/share
base_path = os.path.expanduser("~/.local/share")
app_name = "crewai/credentials"
storage_path = Path(base_path) / app_name
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None:
"""
Save the content to a secure file.
:param filename: The name of the file.
:param content: The content to save.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
with open(file_path, "wb") as f:
f.write(content)
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> Optional[bytes]:
"""
Read the content of a secure file.
:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
if not file_path.exists():
return None
with open(file_path, "rb") as f:
return f.read()
def delete_secure_file(self, filename: str) -> None:
"""
Delete the secure file.
:param filename: The name of the file.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
if file_path.exists():
file_path.unlink(missing_ok=True)

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.159.0,<1.0.0"
"crewai[tools]>=0.165.1,<1.0.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.159.0,<1.0.0",
"crewai[tools]>=0.165.1,<1.0.0",
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.159.0"
"crewai[tools]>=0.165.1"
]
[tool.crewai]

View File

@@ -44,8 +44,9 @@ def migrate_pyproject(input_file, output_file):
]
new_pyproject["project"]["requires-python"] = poetry_data.get("python")
else:
# If it's already in the new format, just copy the project section
# If it's already in the new format, just copy the project and tool sections
new_pyproject["project"] = pyproject_data.get("project", {})
new_pyproject["tool"] = pyproject_data.get("tool", {})
# Migrate or copy dependencies
if "dependencies" in new_pyproject["project"]:

View File

@@ -79,7 +79,6 @@ from crewai.utilities.events.listeners.tracing.trace_listener import (
from crewai.utilities.events.listeners.tracing.utils import (
is_tracing_enabled,
on_first_execution_tracing_confirmation,
)
from crewai.utilities.formatter import (
aggregate_raw_outputs_from_task_outputs,
@@ -286,8 +285,6 @@ class Crew(FlowTrackable, BaseModel):
self._cache_handler = CacheHandler()
event_listener = EventListener()
if on_first_execution_tracing_confirmation():
self.tracing = True
if is_tracing_enabled() or self.tracing:
trace_listener = TraceCollectionListener()
@@ -639,6 +636,7 @@ class Crew(FlowTrackable, BaseModel):
self._inputs = inputs
self._interpolate_inputs(inputs)
self._set_tasks_callbacks()
self._set_allow_crewai_trigger_context_for_first_task()
i18n = I18N(prompt_file=self.prompt_file)
@@ -1508,3 +1506,18 @@ class Crew(FlowTrackable, BaseModel):
"""Reset crew and agent knowledge storage."""
for ks in knowledges:
ks.reset()
def _set_allow_crewai_trigger_context_for_first_task(self):
crewai_trigger_payload = self._inputs and self._inputs.get(
"crewai_trigger_payload"
)
able_to_inject = (
self.tasks and self.tasks[0].allow_crewai_trigger_context is None
)
if (
self.process == Process.sequential
and crewai_trigger_payload
and able_to_inject
):
self.tasks[0].allow_crewai_trigger_context = True

View File

@@ -1,5 +1,5 @@
import threading
from typing import Any
from typing import Any, Optional
from crewai.experimental.evaluation.base_evaluator import AgentEvaluationResult, AggregationStrategy
from crewai.agent import Agent
@@ -15,10 +15,11 @@ from crewai.utilities.events.agent_events import LiteAgentExecutionCompletedEven
from crewai.experimental.evaluation.base_evaluator import AgentAggregatedEvaluationResult, EvaluationScore, MetricCategory
class ExecutionState:
current_agent_id: Optional[str] = None
current_task_id: Optional[str] = None
def __init__(self):
self.traces = {}
self.current_agent_id: str | None = None
self.current_task_id: str | None = None
self.iteration = 1
self.iterations_results = {}
self.agent_evaluators = {}

View File

@@ -40,7 +40,6 @@ from crewai.utilities.events.listeners.tracing.trace_listener import (
)
from crewai.utilities.events.listeners.tracing.utils import (
is_tracing_enabled,
on_first_execution_tracing_confirmation,
)
from crewai.utilities.printer import Printer
@@ -475,15 +474,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._method_outputs: List[Any] = [] # List to store all method outputs
self._completed_methods: Set[str] = set() # Track completed methods for reload
self._persistence: Optional[FlowPersistence] = persistence
self._is_execution_resuming: bool = False
# Initialize state with initial values
self._state = self._create_initial_state()
self.tracing = tracing
if (
on_first_execution_tracing_confirmation()
or is_tracing_enabled()
or self.tracing
):
if is_tracing_enabled() or self.tracing:
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
# Apply any additional kwargs
@@ -834,6 +830,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Clear completed methods and outputs for a fresh start
self._completed_methods.clear()
self._method_outputs.clear()
else:
# We're restoring from persistence, set the flag
self._is_execution_resuming = True
if inputs:
# Override the id in the state if it exists in inputs
@@ -885,6 +884,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
]
await asyncio.gather(*tasks)
# Clear the resumption flag after initial execution completes
self._is_execution_resuming = False
final_output = self._method_outputs[-1] if self._method_outputs else None
crewai_event_bus.emit(
@@ -918,17 +920,56 @@ class Flow(Generic[T], metaclass=FlowMeta):
- Triggers execution of any listeners waiting on this start method
- Part of the flow's initialization sequence
- Skips execution if method was already completed (e.g., after reload)
- Automatically injects crewai_trigger_payload if available in flow inputs
"""
if start_method_name in self._completed_methods:
last_output = self._method_outputs[-1] if self._method_outputs else None
await self._execute_listeners(start_method_name, last_output)
return
if self._is_execution_resuming:
# During resumption, skip execution but continue listeners
last_output = self._method_outputs[-1] if self._method_outputs else None
await self._execute_listeners(start_method_name, last_output)
return
# For cyclic flows, clear from completed to allow re-execution
self._completed_methods.discard(start_method_name)
method = self._methods[start_method_name]
enhanced_method = self._inject_trigger_payload_for_start_method(method)
result = await self._execute_method(
start_method_name, self._methods[start_method_name]
start_method_name, enhanced_method
)
await self._execute_listeners(start_method_name, result)
def _inject_trigger_payload_for_start_method(self, original_method: Callable) -> Callable:
def prepare_kwargs(*args, **kwargs):
inputs = baggage.get_baggage("flow_inputs") or {}
trigger_payload = inputs.get("crewai_trigger_payload")
sig = inspect.signature(original_method)
accepts_trigger_payload = "crewai_trigger_payload" in sig.parameters
if trigger_payload is not None and accepts_trigger_payload:
kwargs["crewai_trigger_payload"] = trigger_payload
elif trigger_payload is not None:
self._log_flow_event(
f"Trigger payload available but {original_method.__name__} doesn't accept crewai_trigger_payload parameter",
color="yellow"
)
return args, kwargs
if asyncio.iscoroutinefunction(original_method):
async def enhanced_method(*args, **kwargs):
args, kwargs = prepare_kwargs(*args, **kwargs)
return await original_method(*args, **kwargs)
else:
def enhanced_method(*args, **kwargs):
args, kwargs = prepare_kwargs(*args, **kwargs)
return original_method(*args, **kwargs)
enhanced_method.__name__ = original_method.__name__
enhanced_method.__doc__ = original_method.__doc__
return enhanced_method
async def _execute_method(
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
) -> Any:
@@ -1020,11 +1061,15 @@ class Flow(Generic[T], metaclass=FlowMeta):
for router_name in routers_triggered:
await self._execute_single_listener(router_name, result)
# After executing router, the router's result is the path
router_result = self._method_outputs[-1]
router_result = (
self._method_outputs[-1] if self._method_outputs else None
)
if router_result: # Only add non-None results
router_results.append(router_result)
current_trigger = (
router_result # Update for next iteration of router chain
str(router_result)
if router_result is not None
else "" # Update for next iteration of router chain
)
# Now execute normal listeners for all router results and the original trigger
@@ -1042,6 +1087,24 @@ class Flow(Generic[T], metaclass=FlowMeta):
]
await asyncio.gather(*tasks)
if current_trigger in router_results:
# Find start methods triggered by this router result
for method_name in self._start_methods:
# Check if this start method is triggered by the current trigger
if method_name in self._listeners:
condition_type, trigger_methods = self._listeners[
method_name
]
if current_trigger in trigger_methods:
# Only execute if this is a cycle (method was already completed)
if method_name in self._completed_methods:
# For router-triggered start methods in cycles, temporarily clear resumption flag
# to allow cyclic execution
was_resuming = self._is_execution_resuming
self._is_execution_resuming = False
await self._execute_start_method(method_name)
self._is_execution_resuming = was_resuming
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> List[str]:
@@ -1079,6 +1142,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
if router_only != is_router:
continue
if not router_only and listener_name in self._start_methods:
continue
if condition_type == "OR":
# If the trigger_method matches any in methods, run this
if trigger_method in methods:
@@ -1128,10 +1194,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
Catches and logs any exceptions during execution, preventing
individual listener failures from breaking the entire flow.
"""
# TODO: greyson fix
# if listener_name in self._completed_methods:
# await self._execute_listeners(listener_name, None)
# return
if listener_name in self._completed_methods:
if self._is_execution_resuming:
# During resumption, skip execution but continue listeners
await self._execute_listeners(listener_name, None)
return
# For cyclic flows, clear from completed to allow re-execution
self._completed_methods.discard(listener_name)
try:
method = self._methods[listener_name]

View File

@@ -11,6 +11,7 @@ import chromadb.errors
from chromadb.api import ClientAPI
from chromadb.api.types import OneOrMany
from chromadb.config import Settings
import warnings
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
@@ -85,6 +86,14 @@ class KnowledgeStorage(BaseKnowledgeStorage):
raise Exception("Collection not initialized")
def initialize_knowledge_storage(self):
# Suppress deprecation warnings from chromadb, which are not relevant to us
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
self.app = create_persistent_client(
path=os.path.join(db_storage_path(), "knowledge"),
settings=Settings(allow_reset=True),

View File

@@ -69,12 +69,7 @@ from crewai.utilities.events.agent_events import (
LiteAgentExecutionStartedEvent,
)
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
from crewai.utilities.events.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMCallType,
)
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.printer import Printer
from crewai.utilities.token_counter_callback import TokenCalcHandler
@@ -519,19 +514,6 @@ class LiteAgent(FlowTrackable, BaseModel):
enforce_rpm_limit(self.request_within_rpm_limit)
llm = cast(LLM, self.llm)
model = llm.model if hasattr(llm, "model") else "unknown"
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(
messages=self._messages,
tools=None,
callbacks=self._callbacks,
from_agent=self,
model=model,
),
)
try:
answer = get_llm_response(
llm=cast(LLM, self.llm),
@@ -541,23 +523,7 @@ class LiteAgent(FlowTrackable, BaseModel):
from_agent=self,
)
# Emit LLM call completed event
crewai_event_bus.emit(
self,
event=LLMCallCompletedEvent(
messages=self._messages,
response=answer,
call_type=LLMCallType.LLM_CALL,
from_agent=self,
model=model,
),
)
except Exception as e:
# Emit LLM call failed event
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=str(e), from_agent=self),
)
raise e
formatted_answer = process_llm_response(answer, self.use_stop_words)

View File

@@ -851,7 +851,9 @@ class LLM(BaseLLM):
return tool_calls
# --- 7) Handle tool calls if present
tool_result = self._handle_tool_call(tool_calls, available_functions)
tool_result = self._handle_tool_call(
tool_calls, available_functions, from_task, from_agent
)
if tool_result is not None:
return tool_result
# --- 8) If tool call handling didn't return a result, emit completion event and return text response
@@ -868,6 +870,8 @@ class LLM(BaseLLM):
self,
tool_calls: List[Any],
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Optional[str]:
"""Handle a tool call from the LLM.
@@ -902,6 +906,8 @@ class LLM(BaseLLM):
event=ToolUsageStartedEvent(
tool_name=function_name,
tool_args=function_args,
from_agent=from_agent,
from_task=from_task,
),
)
@@ -914,12 +920,17 @@ class LLM(BaseLLM):
tool_args=function_args,
started_at=started_at,
finished_at=datetime.now(),
from_task=from_task,
from_agent=from_agent,
),
)
# --- 3.3) Emit success event
self._handle_emit_call_events(
response=result, call_type=LLMCallType.TOOL_CALL
response=result,
call_type=LLMCallType.TOOL_CALL,
from_task=from_task,
from_agent=from_agent,
)
return result
except Exception as e:
@@ -1139,7 +1150,11 @@ class LLM(BaseLLM):
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
# Ollama doesn't supports last message to be 'assistant'
if "ollama" in self.model.lower() and messages and messages[-1]["role"] == "assistant":
if (
"ollama" in self.model.lower()
and messages
and messages[-1]["role"] == "assistant"
):
return messages + [{"role": "user", "content": ""}]
# Handle Anthropic models

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, TYPE_CHECKING
from crewai.memory import (
EntityMemory,
@@ -7,6 +7,10 @@ from crewai.memory import (
ShortTermMemory,
)
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
class ContextualMemory:
def __init__(
@@ -15,11 +19,28 @@ class ContextualMemory:
ltm: LongTermMemory,
em: EntityMemory,
exm: ExternalMemory,
agent: Optional["Agent"] = None,
task: Optional["Task"] = None,
):
self.stm = stm
self.ltm = ltm
self.em = em
self.exm = exm
self.agent = agent
self.task = task
if self.stm is not None:
self.stm.agent = self.agent
self.stm.task = self.task
if self.ltm is not None:
self.ltm.agent = self.agent
self.ltm.task = self.task
if self.em is not None:
self.em.agent = self.agent
self.em.task = self.task
if self.exm is not None:
self.exm.agent = self.agent
self.exm.task = self.task
def build_context_for_task(self, task, context) -> str:
"""
@@ -49,10 +70,7 @@ class ContextualMemory:
stm_results = self.stm.search(query)
formatted_results = "\n".join(
[
f"- {result['context']}"
for result in stm_results
]
[f"- {result['context']}" for result in stm_results]
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
@@ -89,10 +107,7 @@ class ContextualMemory:
em_results = self.em.search(query)
formatted_results = "\n".join(
[
f"- {result['context']}"
for result in em_results
] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
)
return f"Entities:\n{formatted_results}" if em_results else ""

View File

@@ -35,7 +35,7 @@ class EntityMemory(Memory):
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
config = embedder_config.get("config")
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
storage = (
@@ -60,6 +60,8 @@ class EntityMemory(Memory):
event=MemorySaveStartedEvent(
metadata=item.metadata,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
@@ -85,6 +87,8 @@ class EntityMemory(Memory):
metadata=item.metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
@@ -94,6 +98,8 @@ class EntityMemory(Memory):
metadata=item.metadata,
error=str(e),
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
@@ -111,6 +117,8 @@ class EntityMemory(Memory):
limit=limit,
score_threshold=score_threshold,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
@@ -129,6 +137,8 @@ class EntityMemory(Memory):
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)

View File

@@ -53,7 +53,6 @@ class ExternalMemory(Memory):
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
"""Saves a value into the external storage."""
crewai_event_bus.emit(
@@ -61,24 +60,30 @@ class ExternalMemory(Memory):
event=MemorySaveStartedEvent(
value=value,
metadata=metadata,
agent_role=agent,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent)
super().save(value=item.value, metadata=item.metadata, agent=item.agent)
item = ExternalMemoryItem(
value=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
super().save(value=item.value, metadata=item.metadata)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value,
metadata=metadata,
agent_role=agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
@@ -87,9 +92,10 @@ class ExternalMemory(Memory):
event=MemorySaveFailedEvent(
value=value,
metadata=metadata,
agent_role=agent,
error=str(e),
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
@@ -107,6 +113,8 @@ class ExternalMemory(Memory):
limit=limit,
score_threshold=score_threshold,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
@@ -125,6 +133,8 @@ class ExternalMemory(Memory):
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)

View File

@@ -37,13 +37,17 @@ class LongTermMemory(Memory):
metadata=item.metadata,
agent_role=item.agent,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
metadata = item.metadata
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
task_description=item.task,
score=metadata["quality"],
@@ -59,6 +63,8 @@ class LongTermMemory(Memory):
agent_role=item.agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
@@ -74,13 +80,19 @@ class LongTermMemory(Memory):
)
raise
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
self,
task: str,
latest_n: int = 3,
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=task,
limit=latest_n,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
@@ -96,6 +108,8 @@ class LongTermMemory(Memory):
limit=latest_n,
query_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)

View File

@@ -1,7 +1,11 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from pydantic import BaseModel
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
class Memory(BaseModel):
"""
@@ -12,19 +16,38 @@ class Memory(BaseModel):
crew: Optional[Any] = None
storage: Any
_agent: Optional["Agent"] = None
_task: Optional["Task"] = None
def __init__(self, storage: Any, **data: Any):
super().__init__(storage=storage, **data)
@property
def task(self) -> Optional["Task"]:
"""Get the current task associated with this memory."""
return self._task
@task.setter
def task(self, task: Optional["Task"]) -> None:
"""Set the current task associated with this memory."""
self._task = task
@property
def agent(self) -> Optional["Agent"]:
"""Get the current agent associated with this memory."""
return self._agent
@agent.setter
def agent(self, agent: Optional["Agent"]) -> None:
"""Set the current agent associated with this memory."""
self._agent = agent
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
metadata = metadata or {}
if agent:
metadata["agent"] = agent
self.storage.save(value, metadata)

View File

@@ -37,7 +37,7 @@ class ShortTermMemory(Memory):
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
config = embedder_config.get("config")
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
storage = (
@@ -57,34 +57,42 @@ class ShortTermMemory(Memory):
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value,
metadata=metadata,
agent_role=agent,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
item = ShortTermMemoryItem(
data=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
if self._memory_provider == "mem0":
item.data = f"Remember the following insights from Agent run: {item.data}"
item.data = (
f"Remember the following insights from Agent run: {item.data}"
)
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
super().save(value=item.data, metadata=item.metadata)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value,
metadata=metadata,
agent_role=agent,
# agent_role=agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
@@ -93,9 +101,10 @@ class ShortTermMemory(Memory):
event=MemorySaveFailedEvent(
value=value,
metadata=metadata,
agent_role=agent,
error=str(e),
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
@@ -113,6 +122,8 @@ class ShortTermMemory(Memory):
limit=limit,
score_threshold=score_threshold,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
@@ -131,6 +142,8 @@ class ShortTermMemory(Memory):
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)

View File

@@ -12,6 +12,7 @@ from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
import warnings
@contextlib.contextmanager
@@ -62,6 +63,14 @@ class RAGStorage(BaseRAGStorage):
def _initialize_app(self):
from chromadb.config import Settings
# Suppress deprecation warnings from chromadb, which are not relevant to us
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
self._set_embedder_config()
self.app = create_persistent_client(

View File

@@ -1 +1,58 @@
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
import sys
import importlib
from types import ModuleType
from typing import Any
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import set_rag_config
_module_path = __path__
_module_file = __file__
class _RagModule(ModuleType):
"""Module wrapper to intercept attribute setting for config."""
__path__ = _module_path
__file__ = _module_file
def __init__(self, module_name: str):
"""Initialize the module wrapper.
Args:
module_name: Name of the module.
"""
super().__init__(module_name)
def __setattr__(self, name: str, value: RagConfigType) -> None:
"""Set module attributes.
Args:
name: Attribute name.
value: Attribute value.
"""
if name == "config":
return set_rag_config(value)
raise AttributeError(f"Setting attribute '{name}' is not allowed.")
def __getattr__(self, name: str) -> Any:
"""Get module attributes.
Args:
name: Attribute name.
Returns:
The requested attribute.
Raises:
AttributeError: If attribute doesn't exist.
"""
try:
return importlib.import_module(f"{self.__name__}.{name}")
except ImportError:
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
sys.modules[__name__] = _RagModule(__name__)

View File

View File

@@ -0,0 +1,567 @@
"""ChromaDB client implementation."""
from typing import Any
from chromadb.api.types import (
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction,
QueryResult,
)
from typing_extensions import Unpack
from crewai.rag.chromadb.types import (
ChromaDBClientType,
ChromaDBCollectionCreateParams,
ChromaDBCollectionSearchParams,
)
from crewai.rag.chromadb.utils import (
_extract_search_params,
_is_async_client,
_is_sync_client,
_prepare_documents_for_chromadb,
_process_query_results,
)
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
)
from crewai.rag.types import SearchResult
class ChromaDBClient(BaseClient):
"""ChromaDB implementation of the BaseClient protocol.
Provides vector database operations for ChromaDB, supporting both
synchronous and asynchronous clients.
Attributes:
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
embedding_function: Function to generate embeddings for documents.
"""
def __init__(
self,
client: ChromaDBClientType,
embedding_function: ChromaEmbeddingFunction[Embeddable],
) -> None:
"""Initialize ChromaDBClient with client and embedding function.
Args:
client: Pre-configured ChromaDB client instance.
embedding_function: Embedding function for text to vector conversion.
"""
self.client = client
self.embedding_function = embedding_function
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
) -> None:
"""Create a new collection in ChromaDB.
Uses the client's default embedding function if none provided.
Keyword Args:
collection_name: Name of the collection to create. Must be unique.
configuration: Optional collection configuration specifying distance metrics,
HNSW parameters, or other backend-specific settings.
metadata: Optional metadata dictionary to attach to the collection.
embedding_function: Optional custom embedding function. If not provided,
uses the client's default embedding function.
data_loader: Optional data loader for batch loading data into the collection.
get_or_create: If True, returns existing collection if it already exists
instead of raising an error. Defaults to False.
Raises:
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
ValueError: If collection with the same name already exists and get_or_create
is False.
ConnectionError: If unable to connect to ChromaDB server.
Example:
>>> client = ChromaDBClient()
>>> client.create_collection(
... collection_name="documents",
... metadata={"description": "Product documentation"},
... get_or_create=True
... )
"""
if not _is_sync_client(self.client):
raise TypeError(
"Synchronous method create_collection() requires a ClientAPI. "
"Use acreate_collection() for AsyncClientAPI."
)
metadata = kwargs.get("metadata", {})
if "hnsw:space" not in metadata:
metadata["hnsw:space"] = "cosine"
self.client.create_collection(
name=kwargs["collection_name"],
configuration=kwargs.get("configuration"),
metadata=metadata,
embedding_function=kwargs.get(
"embedding_function", self.embedding_function
),
data_loader=kwargs.get("data_loader"),
get_or_create=kwargs.get("get_or_create", False),
)
async def acreate_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
) -> None:
"""Create a new collection in ChromaDB asynchronously.
Creates a new collection with the specified name and optional configuration.
If an embedding function is not provided, uses the client's default embedding function.
Keyword Args:
collection_name: Name of the collection to create. Must be unique.
configuration: Optional collection configuration specifying distance metrics,
HNSW parameters, or other backend-specific settings.
metadata: Optional metadata dictionary to attach to the collection.
embedding_function: Optional custom embedding function. If not provided,
uses the client's default embedding function.
data_loader: Optional data loader for batch loading data into the collection.
get_or_create: If True, returns existing collection if it already exists
instead of raising an error. Defaults to False.
Raises:
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
ValueError: If collection with the same name already exists and get_or_create
is False.
ConnectionError: If unable to connect to ChromaDB server.
Example:
>>> import asyncio
>>> async def main():
... client = ChromaDBClient()
... await client.acreate_collection(
... collection_name="documents",
... metadata={"description": "Product documentation"},
... get_or_create=True
... )
>>> asyncio.run(main())
"""
if not _is_async_client(self.client):
raise TypeError(
"Asynchronous method acreate_collection() requires an AsyncClientAPI. "
"Use create_collection() for ClientAPI."
)
metadata = kwargs.get("metadata", {})
if "hnsw:space" not in metadata:
metadata["hnsw:space"] = "cosine"
await self.client.create_collection(
name=kwargs["collection_name"],
configuration=kwargs.get("configuration"),
metadata=metadata,
embedding_function=kwargs.get(
"embedding_function", self.embedding_function
),
data_loader=kwargs.get("data_loader"),
get_or_create=kwargs.get("get_or_create", False),
)
def get_or_create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
) -> Any:
"""Get an existing collection or create it if it doesn't exist.
Returns existing collection if found, otherwise creates a new one.
Keyword Args:
collection_name: Name of the collection to get or create.
configuration: Optional collection configuration specifying distance metrics,
HNSW parameters, or other backend-specific settings.
metadata: Optional metadata dictionary to attach to the collection.
embedding_function: Optional custom embedding function. If not provided,
uses the client's default embedding function.
data_loader: Optional data loader for batch loading data into the collection.
Returns:
A ChromaDB Collection object.
Raises:
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
ConnectionError: If unable to connect to ChromaDB server.
Example:
>>> client = ChromaDBClient()
>>> collection = client.get_or_create_collection(
... collection_name="documents",
... metadata={"description": "Product documentation"}
... )
"""
if not _is_sync_client(self.client):
raise TypeError(
"Synchronous method get_or_create_collection() requires a ClientAPI. "
"Use aget_or_create_collection() for AsyncClientAPI."
)
metadata = kwargs.get("metadata", {})
if "hnsw:space" not in metadata:
metadata["hnsw:space"] = "cosine"
return self.client.get_or_create_collection(
name=kwargs["collection_name"],
configuration=kwargs.get("configuration"),
metadata=metadata,
embedding_function=kwargs.get(
"embedding_function", self.embedding_function
),
data_loader=kwargs.get("data_loader"),
)
async def aget_or_create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
) -> Any:
"""Get an existing collection or create it if it doesn't exist asynchronously.
Returns existing collection if found, otherwise creates a new one.
Keyword Args:
collection_name: Name of the collection to get or create.
configuration: Optional collection configuration specifying distance metrics,
HNSW parameters, or other backend-specific settings.
metadata: Optional metadata dictionary to attach to the collection.
embedding_function: Optional custom embedding function. If not provided,
uses the client's default embedding function.
data_loader: Optional data loader for batch loading data into the collection.
Returns:
A ChromaDB AsyncCollection object.
Raises:
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
ConnectionError: If unable to connect to ChromaDB server.
Example:
>>> import asyncio
>>> async def main():
... client = ChromaDBClient()
... collection = await client.aget_or_create_collection(
... collection_name="documents",
... metadata={"description": "Product documentation"}
... )
>>> asyncio.run(main())
"""
if not _is_async_client(self.client):
raise TypeError(
"Asynchronous method aget_or_create_collection() requires an AsyncClientAPI. "
"Use get_or_create_collection() for ClientAPI."
)
metadata = kwargs.get("metadata", {})
if "hnsw:space" not in metadata:
metadata["hnsw:space"] = "cosine"
return await self.client.get_or_create_collection(
name=kwargs["collection_name"],
configuration=kwargs.get("configuration"),
metadata=metadata,
embedding_function=kwargs.get(
"embedding_function", self.embedding_function
),
data_loader=kwargs.get("data_loader"),
)
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection.
Performs an upsert operation - documents with existing IDs are updated.
Generates embeddings automatically using the configured embedding function.
Keyword Args:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing:
- content: The text content (required)
- doc_id: Optional unique identifier (auto-generated if missing)
- metadata: Optional metadata dictionary
Raises:
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
ValueError: If collection doesn't exist or documents list is empty.
ConnectionError: If unable to connect to ChromaDB server.
"""
if not _is_sync_client(self.client):
raise TypeError(
"Synchronous method add_documents() requires a ClientAPI. "
"Use aadd_documents() for AsyncClientAPI."
)
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
if not documents:
raise ValueError("Documents list cannot be empty")
collection = self.client.get_collection(
name=collection_name,
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
collection.add(
ids=prepared.ids,
documents=prepared.texts,
metadatas=prepared.metadatas,
)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection asynchronously.
Performs an upsert operation - documents with existing IDs are updated.
Generates embeddings automatically using the configured embedding function.
Keyword Args:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing:
- content: The text content (required)
- doc_id: Optional unique identifier (auto-generated if missing)
- metadata: Optional metadata dictionary
Raises:
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
ValueError: If collection doesn't exist or documents list is empty.
ConnectionError: If unable to connect to ChromaDB server.
"""
if not _is_async_client(self.client):
raise TypeError(
"Asynchronous method aadd_documents() requires an AsyncClientAPI. "
"Use add_documents() for ClientAPI."
)
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
if not documents:
raise ValueError("Documents list cannot be empty")
collection = await self.client.get_collection(
name=collection_name,
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
await collection.add(
ids=prepared.ids,
documents=prepared.texts,
metadatas=prepared.metadatas,
)
def search(
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
) -> list[SearchResult]:
"""Search for similar documents using a query.
Performs semantic search to find documents similar to the query text.
Uses the configured embedding function to generate query embeddings.
Keyword Args:
collection_name: Name of the collection to search in.
query: The text query to search for.
limit: Maximum number of results to return (default: 10).
metadata_filter: Optional filter for metadata fields.
score_threshold: Optional minimum similarity score (0-1) for results.
where: Optional ChromaDB where clause for metadata filtering.
where_document: Optional ChromaDB where clause for document content filtering.
include: Optional list of fields to include in results.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
Raises:
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
ValueError: If collection doesn't exist.
ConnectionError: If unable to connect to ChromaDB server.
"""
if not _is_sync_client(self.client):
raise TypeError(
"Synchronous method search() requires a ClientAPI. "
"Use asearch() for AsyncClientAPI."
)
params = _extract_search_params(kwargs)
collection = self.client.get_collection(
name=params.collection_name,
embedding_function=self.embedding_function,
)
where = params.where if params.where is not None else params.metadata_filter
results: QueryResult = collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
)
return _process_query_results(
collection=collection,
results=results,
params=params,
)
async def asearch(
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
) -> list[SearchResult]:
"""Search for similar documents using a query asynchronously.
Performs semantic search to find documents similar to the query text.
Uses the configured embedding function to generate query embeddings.
Keyword Args:
collection_name: Name of the collection to search in.
query: The text query to search for.
limit: Maximum number of results to return (default: 10).
metadata_filter: Optional filter for metadata fields.
score_threshold: Optional minimum similarity score (0-1) for results.
where: Optional ChromaDB where clause for metadata filtering.
where_document: Optional ChromaDB where clause for document content filtering.
include: Optional list of fields to include in results.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
Raises:
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
ValueError: If collection doesn't exist.
ConnectionError: If unable to connect to ChromaDB server.
"""
if not _is_async_client(self.client):
raise TypeError(
"Asynchronous method asearch() requires an AsyncClientAPI. "
"Use search() for ClientAPI."
)
params = _extract_search_params(kwargs)
collection = await self.client.get_collection(
name=params.collection_name,
embedding_function=self.embedding_function,
)
where = params.where if params.where is not None else params.metadata_filter
results: QueryResult = await collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
)
return _process_query_results(
collection=collection,
results=results,
params=params,
)
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data.
Permanently removes a collection and all documents, embeddings, and metadata it contains.
This operation cannot be undone.
Keyword Args:
collection_name: Name of the collection to delete.
Raises:
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
ValueError: If collection doesn't exist.
ConnectionError: If unable to connect to ChromaDB server.
Example:
>>> client = ChromaDBClient()
>>> client.delete_collection(collection_name="old_documents")
"""
if not _is_sync_client(self.client):
raise TypeError(
"Synchronous method delete_collection() requires a ClientAPI. "
"Use adelete_collection() for AsyncClientAPI."
)
collection_name = kwargs["collection_name"]
self.client.delete_collection(name=collection_name)
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data asynchronously.
Permanently removes a collection and all documents, embeddings, and metadata it contains.
This operation cannot be undone.
Keyword Args:
collection_name: Name of the collection to delete.
Raises:
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
ValueError: If collection doesn't exist.
ConnectionError: If unable to connect to ChromaDB server.
Example:
>>> import asyncio
>>> async def main():
... client = ChromaDBClient()
... await client.adelete_collection(collection_name="old_documents")
>>> asyncio.run(main())
"""
if not _is_async_client(self.client):
raise TypeError(
"Asynchronous method adelete_collection() requires an AsyncClientAPI. "
"Use delete_collection() for ClientAPI."
)
collection_name = kwargs["collection_name"]
await self.client.delete_collection(name=collection_name)
def reset(self) -> None:
"""Reset the vector database by deleting all collections and data.
Completely clears the ChromaDB instance, removing all collections,
documents, embeddings, and metadata. This operation cannot be undone.
Use with extreme caution in production environments.
Raises:
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
ConnectionError: If unable to connect to ChromaDB server.
Example:
>>> client = ChromaDBClient()
>>> client.reset() # Removes ALL data from ChromaDB
"""
if not _is_sync_client(self.client):
raise TypeError(
"Synchronous method reset() requires a ClientAPI. "
"Use areset() for AsyncClientAPI."
)
self.client.reset()
async def areset(self) -> None:
"""Reset the vector database by deleting all collections and data asynchronously.
Completely clears the ChromaDB instance, removing all collections,
documents, embeddings, and metadata. This operation cannot be undone.
Use with extreme caution in production environments.
Raises:
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
ConnectionError: If unable to connect to ChromaDB server.
Example:
>>> import asyncio
>>> async def main():
... client = ChromaDBClient()
... await client.areset() # Removes ALL data from ChromaDB
>>> asyncio.run(main())
"""
if not _is_async_client(self.client):
raise TypeError(
"Asynchronous method areset() requires an AsyncClientAPI. "
"Use reset() for ClientAPI."
)
await self.client.reset()

View File

@@ -0,0 +1,59 @@
"""ChromaDB configuration model."""
import warnings
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.chromadb.constants import (
DEFAULT_TENANT,
DEFAULT_DATABASE,
DEFAULT_STORAGE_PATH,
)
warnings.filterwarnings(
"ignore",
message=".*Mixing V1 models and V2 models.*",
category=UserWarning,
module="pydantic._internal._generate_schema",
)
def _default_settings() -> Settings:
"""Create default ChromaDB settings.
Returns:
Settings with persistent storage and reset enabled.
"""
return Settings(
persist_directory=DEFAULT_STORAGE_PATH,
allow_reset=True,
is_persistent=True,
)
def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
"""Create default ChromaDB embedding function.
Returns:
Default embedding function using all-MiniLM-L6-v2 via ONNX.
"""
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
@pyd_dataclass(frozen=True)
class ChromaDBConfig(BaseRagConfig):
"""Configuration for ChromaDB client."""
provider: Literal["chromadb"] = field(default="chromadb", init=False)
tenant: str = DEFAULT_TENANT
database: str = DEFAULT_DATABASE
settings: Settings = field(default_factory=_default_settings)
embedding_function: ChromaEmbeddingFunctionWrapper = field(
default_factory=_default_embedding_function
)

View File

@@ -0,0 +1,10 @@
"""Constants for ChromaDB configuration."""
import os
from typing import Final
from crewai.utilities.paths import db_storage_path
DEFAULT_TENANT: Final[str] = "default_tenant"
DEFAULT_DATABASE: Final[str] = "default_database"
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb")

View File

@@ -0,0 +1,24 @@
"""Factory functions for creating ChromaDB clients."""
from chromadb import Client
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.client import ChromaDBClient
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
"""Create a ChromaDBClient from configuration.
Args:
config: ChromaDB configuration object.
Returns:
Configured ChromaDBClient instance.
"""
return ChromaDBClient(
client=Client(
settings=config.settings, tenant=config.tenant, database=config.database
),
embedding_function=config.embedding_function,
)

View File

@@ -0,0 +1,102 @@
"""Type definitions specific to ChromaDB implementation."""
from collections.abc import Mapping
from typing import Any, NamedTuple
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from chromadb.api import ClientAPI, AsyncClientAPI
from chromadb.api.configuration import CollectionConfigurationInterface
from chromadb.api.types import (
CollectionMetadata,
DataLoader,
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction,
Include,
Loadable,
Where,
WhereDocument,
)
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
ChromaDBClientType = ClientAPI | AsyncClientAPI
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]):
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for ChromaDB EmbeddingFunction.
This allows Pydantic to handle ChromaDB's EmbeddingFunction type
without requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()
class PreparedDocuments(NamedTuple):
"""Prepared documents ready for ChromaDB insertion.
Attributes:
ids: List of document IDs
texts: List of document texts
metadatas: List of document metadata mappings
"""
ids: list[str]
texts: list[str]
metadatas: list[Mapping[str, str | int | float | bool]]
class ExtractedSearchParams(NamedTuple):
"""Extracted search parameters for ChromaDB queries.
Attributes:
collection_name: Name of the collection to search
query: Search query text
limit: Maximum number of results
metadata_filter: Optional metadata filter
score_threshold: Optional minimum similarity score
where: Optional ChromaDB where clause
where_document: Optional ChromaDB document filter
include: Fields to include in results
"""
collection_name: str
query: str
limit: int
metadata_filter: dict[str, Any] | None
score_threshold: float | None
where: Where | None
where_document: WhereDocument | None
include: Include
class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
"""Parameters for creating a ChromaDB collection.
This class extends BaseCollectionParams to include any additional
parameters specific to ChromaDB collection creation.
"""
configuration: CollectionConfigurationInterface
metadata: CollectionMetadata
embedding_function: ChromaEmbeddingFunction[Embeddable]
data_loader: DataLoader[Loadable]
get_or_create: bool
class ChromaDBCollectionSearchParams(BaseCollectionSearchParams, total=False):
"""Parameters for searching a ChromaDB collection.
This class extends BaseCollectionSearchParams to include ChromaDB-specific
search parameters like where clauses and include options.
"""
where: Where
where_document: WhereDocument
include: Include

View File

@@ -0,0 +1,218 @@
"""Utility functions for ChromaDB client implementation."""
import hashlib
from collections.abc import Mapping
from typing import Literal, TypeGuard, cast
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.types import (
Include,
IncludeEnum,
QueryResult,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from crewai.rag.chromadb.types import (
ChromaDBClientType,
ChromaDBCollectionSearchParams,
ExtractedSearchParams,
PreparedDocuments,
)
from crewai.rag.types import BaseRecord, SearchResult
def _is_sync_client(client: ChromaDBClientType) -> TypeGuard[ClientAPI]:
"""Type guard to check if the client is a synchronous ClientAPI.
Args:
client: The client to check.
Returns:
True if the client is a ClientAPI, False otherwise.
"""
return isinstance(client, ClientAPI)
def _is_async_client(client: ChromaDBClientType) -> TypeGuard[AsyncClientAPI]:
"""Type guard to check if the client is an asynchronous AsyncClientAPI.
Args:
client: The client to check.
Returns:
True if the client is an AsyncClientAPI, False otherwise.
"""
return isinstance(client, AsyncClientAPI)
def _prepare_documents_for_chromadb(
documents: list[BaseRecord],
) -> PreparedDocuments:
"""Prepare documents for ChromaDB by extracting IDs, texts, and metadata.
Args:
documents: List of BaseRecord documents to prepare.
Returns:
PreparedDocuments with ids, texts, and metadatas ready for ChromaDB.
"""
ids: list[str] = []
texts: list[str] = []
metadatas: list[Mapping[str, str | int | float | bool]] = []
for doc in documents:
if "doc_id" in doc:
ids.append(doc["doc_id"])
else:
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
ids.append(content_hash)
texts.append(doc["content"])
metadata = doc.get("metadata")
if metadata:
if isinstance(metadata, list):
metadatas.append(metadata[0] if metadata else {})
else:
metadatas.append(metadata)
else:
metadatas.append({})
return PreparedDocuments(ids, texts, metadatas)
def _extract_search_params(
kwargs: ChromaDBCollectionSearchParams,
) -> ExtractedSearchParams:
"""Extract search parameters from kwargs.
Args:
kwargs: Keyword arguments containing search parameters.
Returns:
ExtractedSearchParams with all extracted parameters.
"""
return ExtractedSearchParams(
collection_name=kwargs["collection_name"],
query=kwargs["query"],
limit=kwargs.get("limit", 10),
metadata_filter=kwargs.get("metadata_filter"),
score_threshold=kwargs.get("score_threshold"),
where=kwargs.get("where"),
where_document=kwargs.get("where_document"),
include=kwargs.get(
"include",
[IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
),
)
def _convert_distance_to_score(
distance: float,
distance_metric: Literal["l2", "cosine", "ip"],
) -> float:
"""Convert ChromaDB distance to similarity score.
Notes:
Assuming all embedding are unit-normalized for now, including custom embeddings.
Args:
distance: The distance value from ChromaDB.
distance_metric: The distance metric used ("l2", "cosine", or "ip").
Returns:
Similarity score in range [0, 1] where 1 is most similar.
"""
if distance_metric == "cosine":
score = 1.0 - 0.5 * distance
return max(0.0, min(1.0, score))
raise ValueError(f"Unsupported distance metric: {distance_metric}")
def _convert_chromadb_results_to_search_results(
results: QueryResult,
include: Include,
distance_metric: Literal["l2", "cosine", "ip"],
score_threshold: float | None = None,
) -> list[SearchResult]:
"""Convert ChromaDB query results to SearchResult format.
Args:
results: ChromaDB query results.
include: List of fields that were included in the query.
distance_metric: The distance metric used by the collection.
score_threshold: Optional minimum similarity score (0-1) for results.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
"""
search_results: list[SearchResult] = []
include_strings = [item.value for item in include]
ids = results["ids"][0] if results.get("ids") else []
documents_list = results.get("documents")
documents = (
documents_list[0] if documents_list and "documents" in include_strings else []
)
metadatas_list = results.get("metadatas")
metadatas = (
metadatas_list[0] if metadatas_list and "metadatas" in include_strings else []
)
distances_list = results.get("distances")
distances = (
distances_list[0] if distances_list and "distances" in include_strings else []
)
for i, doc_id in enumerate(ids):
if not distances or i >= len(distances):
continue
distance = distances[i]
score = _convert_distance_to_score(
distance=distance, distance_metric=distance_metric
)
if score_threshold and score < score_threshold:
continue
result: SearchResult = {
"id": doc_id,
"content": documents[i] if documents and i < len(documents) else "",
"metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {},
"score": score,
}
search_results.append(result)
return search_results
def _process_query_results(
collection: Collection | AsyncCollection,
results: QueryResult,
params: ExtractedSearchParams,
) -> list[SearchResult]:
"""Process ChromaDB query results and convert to SearchResult format.
Args:
collection: The ChromaDB collection (sync or async) that was queried.
results: Raw query results from ChromaDB.
params: The search parameters used for the query.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
"""
distance_metric = cast(
Literal["l2", "cosine", "ip"],
collection.metadata.get("hnsw:space", "l2") if collection.metadata else "l2",
)
return _convert_chromadb_results_to_search_results(
results=results,
include=params.include,
distance_metric=distance_metric,
score_threshold=params.score_threshold,
)

View File

@@ -0,0 +1 @@
"""RAG client configuration management using ContextVars for thread-safe provider switching."""

View File

@@ -0,0 +1,16 @@
"""Base configuration class for RAG providers."""
from dataclasses import field
from typing import Any
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.optional_imports.types import SupportedProvider
@pyd_dataclass(frozen=True)
class BaseRagConfig:
"""Base class for RAG configuration with Pydantic serialization support."""
provider: SupportedProvider = field(init=False)
embedding_function: Any | None = field(default=None)

View File

@@ -0,0 +1,8 @@
"""Constants for RAG configuration."""
from typing import Final
DISCRIMINATOR: Final[str] = "provider"
DEFAULT_RAG_CONFIG_PATH: Final[str] = "crewai.rag.chromadb.config"
DEFAULT_RAG_CONFIG_CLASS: Final[str] = "ChromaDBConfig"

View File

@@ -0,0 +1 @@
"""Optional imports for RAG configuration providers."""

View File

@@ -0,0 +1,26 @@
"""Base classes for missing provider configurations."""
from typing import Literal
from dataclasses import field
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class _MissingProvider:
"""Base class for missing provider configurations.
Raises RuntimeError when instantiated to indicate missing dependencies.
"""
provider: Literal["chromadb", "qdrant", "elasticsearch", "__missing__"] = field(
default="__missing__"
)
def __post_init__(self) -> None:
"""Raises error indicating the provider is not installed."""
raise RuntimeError(
f"provider '{self.provider}' requested but not installed. "
f"Install the extra: `uv add crewai'[{self.provider}]'`."
)

View File

@@ -0,0 +1,37 @@
"""Protocol definitions for RAG factory modules."""
from __future__ import annotations
from typing import Protocol, TYPE_CHECKING
if TYPE_CHECKING:
from crewai.rag.chromadb.client import ChromaDBClient
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
from crewai.rag.elasticsearch.client import ElasticsearchClient
from crewai.rag.elasticsearch.config import ElasticsearchConfig
class ChromaFactoryModule(Protocol):
"""Protocol for ChromaDB factory module."""
def create_client(self, config: ChromaDBConfig) -> ChromaDBClient:
"""Creates a ChromaDB client from configuration."""
...
class QdrantFactoryModule(Protocol):
"""Protocol for Qdrant factory module."""
def create_client(self, config: QdrantConfig) -> QdrantClient:
"""Creates a Qdrant client from configuration."""
...
class ElasticsearchFactoryModule(Protocol):
"""Protocol for Elasticsearch factory module."""
def create_client(self, config: ElasticsearchConfig) -> ElasticsearchClient:
"""Creates an Elasticsearch client from configuration."""
...

View File

@@ -0,0 +1,29 @@
"""Provider-specific missing configuration classes."""
from typing import Literal
from dataclasses import field
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.optional_imports.base import _MissingProvider
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingChromaDBConfig(_MissingProvider):
"""Placeholder for missing ChromaDB configuration."""
provider: Literal["chromadb"] = field(default="chromadb")
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingQdrantConfig(_MissingProvider):
"""Placeholder for missing Qdrant configuration."""
provider: Literal["qdrant"] = field(default="qdrant")
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingElasticsearchConfig(_MissingProvider):
"""Placeholder for missing Elasticsearch configuration."""
provider: Literal["elasticsearch"] = field(default="elasticsearch")

View File

@@ -0,0 +1,8 @@
"""Type definitions for optional imports."""
from typing import Annotated, Literal
SupportedProvider = Annotated[
Literal["chromadb", "qdrant", "elasticsearch"],
"Supported RAG provider types, add providers here as they become available",
]

View File

@@ -0,0 +1,44 @@
"""Type definitions for RAG configuration."""
from typing import Annotated, TypeAlias, TYPE_CHECKING
from pydantic import Field
from crewai.rag.config.constants import DISCRIMINATOR
# Linter freaks out on conditional imports, assigning in the type checking fixes it
if TYPE_CHECKING:
from crewai.rag.chromadb.config import ChromaDBConfig as ChromaDBConfig_
ChromaDBConfig = ChromaDBConfig_
from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_
QdrantConfig = QdrantConfig_
from crewai.rag.elasticsearch.config import ElasticsearchConfig as ElasticsearchConfig_
ElasticsearchConfig = ElasticsearchConfig_
else:
try:
from crewai.rag.chromadb.config import ChromaDBConfig
except ImportError:
from crewai.rag.config.optional_imports.providers import (
MissingChromaDBConfig as ChromaDBConfig,
)
try:
from crewai.rag.qdrant.config import QdrantConfig
except ImportError:
from crewai.rag.config.optional_imports.providers import (
MissingQdrantConfig as QdrantConfig,
)
try:
from crewai.rag.elasticsearch.config import ElasticsearchConfig
except ImportError:
from crewai.rag.config.optional_imports.providers import (
MissingElasticsearchConfig as ElasticsearchConfig,
)
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig | ElasticsearchConfig
RagConfigType: TypeAlias = Annotated[
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
]

View File

@@ -0,0 +1,86 @@
"""RAG client configuration utilities."""
from contextvars import ContextVar
from pydantic import BaseModel, Field
from crewai.utilities.import_utils import require
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.constants import (
DEFAULT_RAG_CONFIG_PATH,
DEFAULT_RAG_CONFIG_CLASS,
)
from crewai.rag.factory import create_client
class RagContext(BaseModel):
"""Context holding RAG configuration and client instance."""
config: RagConfigType = Field(..., description="RAG provider configuration")
client: BaseClient | None = Field(
default=None, description="Instantiated RAG client"
)
_rag_context: ContextVar[RagContext | None] = ContextVar("_rag_context", default=None)
def set_rag_config(config: RagConfigType) -> None:
"""Set global RAG client configuration and instantiate the client.
Args:
config: The RAG client configuration (ChromaDBConfig).
"""
client = create_client(config)
context = RagContext(config=config, client=client)
_rag_context.set(context)
def get_rag_config() -> RagConfigType:
"""Get current RAG configuration.
Returns:
The current RAG configuration object.
"""
context = _rag_context.get()
if context is None:
module = require(DEFAULT_RAG_CONFIG_PATH, purpose="RAG configuration")
config_class = getattr(module, DEFAULT_RAG_CONFIG_CLASS)
default_config = config_class()
set_rag_config(default_config)
context = _rag_context.get()
if context is None or context.config is None:
raise ValueError(
"RAG configuration is not set. Please set the RAG config first."
)
return context.config
def get_rag_client() -> BaseClient:
"""Get the current RAG client instance.
Returns:
The current RAG client, creating one if needed.
"""
context = _rag_context.get()
if context is None:
get_rag_config()
context = _rag_context.get()
if context and context.client is None:
context.client = create_client(context.config)
if context is None or context.client is None:
raise ValueError(
"RAG client is not configured. Please set the RAG config first."
)
return context.client
def clear_rag_config() -> None:
"""Clear the current RAG configuration and client, reverting to defaults."""
_rag_context.set(None)

View File

@@ -0,0 +1 @@
"""Core abstract base classes and protocols for RAG systems."""

View File

@@ -0,0 +1,446 @@
"""Protocol for vector database client implementations."""
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable, Annotated
from typing_extensions import Unpack, Required, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from crewai.rag.types import (
EmbeddingFunction,
BaseRecord,
SearchResult,
)
class BaseCollectionParams(TypedDict):
"""Base parameters for collection operations.
Attributes:
collection_name: The name of the collection/index to operate on.
"""
collection_name: Required[
Annotated[
str,
"Name of the collection/index. Implementations may have specific constraints (e.g., character limits, allowed characters, case sensitivity).",
]
]
class BaseCollectionAddParams(BaseCollectionParams):
"""Parameters for adding documents to a collection.
Extends BaseCollectionParams with document-specific fields.
Attributes:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dictionaries containing document data.
"""
documents: list[BaseRecord]
class BaseCollectionSearchParams(BaseCollectionParams, total=False):
"""Parameters for searching within a collection.
Extends BaseCollectionParams with search-specific optional fields.
All fields except collection_name and query are optional.
Attributes:
query: The text query to search for (required).
limit: Maximum number of results to return.
metadata_filter: Filter results by metadata fields.
score_threshold: Minimum similarity score for results (0-1).
"""
query: Required[str]
limit: int
metadata_filter: dict[str, Any]
score_threshold: float
@runtime_checkable
class BaseClient(Protocol):
"""Protocol for vector store client implementations.
This protocol defines the interface that all vector store client implementations
must follow. It provides a consistent API for storing and retrieving
documents with their vector embeddings across different vector database
backends (e.g., Qdrant, ChromaDB, Weaviate). Implementing classes should
handle connection management, data persistence, and vector similarity
search operations specific to their backend.
Implementation Guidelines:
Implementations should accept BaseClientParams in their constructor to allow
passing pre-configured client instances:
class MyVectorClient:
def __init__(self, client: Any | None = None, **kwargs):
if client:
self.client = client
else:
self.client = self._create_default_client(**kwargs)
Notes:
This protocol replaces the former BaseRAGStorage abstraction,
providing a cleaner interface for vector store operations.
Attributes:
embedding_function: Callable that takes a list of text strings
and returns a list of embedding vectors. Implementations
should always provide a default embedding function.
client: The underlying vector database client instance. This could be
passed via BaseClientParams during initialization or created internally.
"""
client: Any
embedding_function: EmbeddingFunction
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for BaseClient Protocol.
This allows the Protocol to be used in Pydantic models without
requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()
@abstractmethod
def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Create a new collection/index in the vector database.
Keyword Args:
collection_name: The name of the collection to create. Must be unique within
the vector database instance.
Raises:
ValueError: If collection name already exists.
ConnectionError: If unable to connect to the vector database backend.
"""
...
@abstractmethod
async def acreate_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Create a new collection/index in the vector database asynchronously.
Keyword Args:
collection_name: The name of the collection to create. Must be unique within
the vector database instance.
Raises:
ValueError: If collection name already exists.
ConnectionError: If unable to connect to the vector database backend.
"""
...
@abstractmethod
def get_or_create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> Any:
"""Get an existing collection or create it if it doesn't exist.
This method provides a convenient way to ensure a collection exists
without having to check for its existence first.
Keyword Args:
collection_name: The name of the collection to get or create.
Returns:
A collection object whose type depends on the backend implementation.
This could be a collection reference, ID, or client object.
Raises:
ValueError: If unable to create the collection.
ConnectionError: If unable to connect to the vector database backend.
"""
...
@abstractmethod
async def aget_or_create_collection(
self, **kwargs: Unpack[BaseCollectionParams]
) -> Any:
"""Get an existing collection or create it if it doesn't exist asynchronously.
Keyword Args:
collection_name: The name of the collection to get or create.
Returns:
A collection object whose type depends on the backend implementation.
Raises:
ValueError: If unable to create the collection.
ConnectionError: If unable to connect to the vector database backend.
"""
...
@abstractmethod
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection.
This method performs an upsert operation - if a document with the same ID
already exists, it will be updated with the new content and metadata.
Implementations should handle embedding generation internally based on
the configured embedding function.
Keyword Args:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing:
- content: The text content (required)
- doc_id: Optional unique identifier (auto-generated from content hash if missing)
- metadata: Optional metadata dictionary
Embeddings will be generated automatically.
Raises:
ValueError: If collection doesn't exist or documents list is empty.
TypeError: If documents are not BaseRecord dict instances.
ConnectionError: If unable to connect to the vector database backend.
Example:
>>> from crewai.rag.chromadb.client import ChromaDBClient
>>> from crewai.rag.types import BaseRecord
>>> client = ChromaDBClient()
>>>
>>> records: list[BaseRecord] = [
... {
... "content": "Machine learning basics",
... "metadata": {"source": "file3", "topic": "ML"}
... },
... {
... "doc_id": "custom_id",
... "content": "Deep learning fundamentals",
... "metadata": {"source": "file4", "topic": "DL"}
... }
... ]
>>> client.add_documents(collection_name="my_docs", documents=records)
>>>
>>> records_with_id: list[BaseRecord] = [
... {
... "doc_id": "nlp_001",
... "content": "Advanced NLP techniques",
... "metadata": {"source": "file5", "topic": "NLP"}
... }
... ]
>>> client.add_documents(collection_name="my_docs", documents=records_with_id)
"""
...
@abstractmethod
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection asynchronously.
Implementations should handle embedding generation internally based on
the configured embedding function.
Keyword Args:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing:
- content: The text content (required)
- doc_id: Optional unique identifier (auto-generated from content hash if missing)
- metadata: Optional metadata dictionary
Embeddings will be generated automatically.
Raises:
ValueError: If collection doesn't exist or documents list is empty.
TypeError: If documents are not BaseRecord dict instances.
ConnectionError: If unable to connect to the vector database backend.
Example:
>>> import asyncio
>>> from crewai.rag.chromadb.client import ChromaDBClient
>>> from crewai.rag.types import BaseRecord
>>>
>>> async def add_documents():
... client = ChromaDBClient()
...
... records: list[BaseRecord] = [
... {
... "doc_id": "doc2",
... "content": "Async operations in Python",
... "metadata": {"source": "file2", "topic": "async"}
... }
... ]
... await client.aadd_documents(collection_name="my_docs", documents=records)
...
>>> asyncio.run(add_documents())
"""
...
@abstractmethod
def search(
self, **kwargs: Unpack[BaseCollectionSearchParams]
) -> list[SearchResult]:
"""Search for similar documents using a query.
Performs a vector similarity search to find the most similar documents
to the provided query.
Keyword Args:
collection_name: The name of the collection to search in.
query: The text query to search for. The implementation handles
embedding generation internally.
limit: Maximum number of results to return. Defaults to 10.
metadata_filter: Optional metadata filter to apply to the search. The exact
format depends on the backend, but typically supports equality
and range queries on metadata fields.
score_threshold: Optional minimum similarity score threshold. Only
results with scores >= this threshold will be returned. The
score interpretation depends on the distance metric used.
Returns:
A list of SearchResult dictionaries ordered by similarity score in
descending order. Each result contains:
- id: Document ID
- content: Document text content
- metadata: Document metadata
- score: Similarity score (0-1, higher is better)
Raises:
ValueError: If collection doesn't exist.
ConnectionError: If unable to connect to the vector database backend.
Example:
>>> from crewai.rag.chromadb.client import ChromaDBClient
>>> client = ChromaDBClient()
>>>
>>> results = client.search(
... collection_name="my_docs",
... query="What is machine learning?",
... limit=5,
... metadata_filter={"source": "file1"},
... score_threshold=0.7
... )
>>> for result in results:
... print(f"{result['id']}: {result['score']:.2f}")
"""
...
@abstractmethod
async def asearch(
self, **kwargs: Unpack[BaseCollectionSearchParams]
) -> list[SearchResult]:
"""Search for similar documents using a query asynchronously.
Keyword Args:
collection_name: The name of the collection to search in.
query: The text query to search for. The implementation handles
embedding generation internally.
limit: Maximum number of results to return. Defaults to 10.
metadata_filter: Optional metadata filter to apply to the search.
score_threshold: Optional minimum similarity score threshold.
Returns:
A list of SearchResult dictionaries ordered by similarity score.
Raises:
ValueError: If collection doesn't exist.
ConnectionError: If unable to connect to the vector database backend.
Example:
>>> import asyncio
>>> from crewai.rag.chromadb.client import ChromaDBClient
>>>
>>> async def search_documents():
... client = ChromaDBClient()
... results = await client.asearch(
... collection_name="my_docs",
... query="Python programming best practices",
... limit=5,
... metadata_filter={"source": "file1"},
... score_threshold=0.7
... )
... for result in results:
... print(f"{result['id']}: {result['score']:.2f}")
...
>>> asyncio.run(search_documents())
"""
...
@abstractmethod
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data.
This operation is irreversible and will permanently remove all documents,
embeddings, and metadata associated with the collection.
Keyword Args:
collection_name: The name of the collection to delete.
Raises:
ValueError: If the collection doesn't exist.
ConnectionError: If unable to connect to the vector database backend.
Example:
>>> from crewai.rag.chromadb.client import ChromaDBClient
>>> client = ChromaDBClient()
>>> client.delete_collection(collection_name="old_docs")
>>> print("Collection 'old_docs' deleted successfully")
"""
...
@abstractmethod
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data asynchronously.
Keyword Args:
collection_name: The name of the collection to delete.
Raises:
ValueError: If the collection doesn't exist.
ConnectionError: If unable to connect to the vector database backend.
Example:
>>> import asyncio
>>> from crewai.rag.chromadb.client import ChromaDBClient
>>>
>>> async def delete_old_collection():
... client = ChromaDBClient()
... await client.adelete_collection(collection_name="old_docs")
... print("Collection 'old_docs' deleted successfully")
...
>>> asyncio.run(delete_old_collection())
"""
...
@abstractmethod
def reset(self) -> None:
"""Reset the vector database by deleting all collections and data.
This method provides a way to completely clear the vector database,
removing all collections and their contents. Use with caution as
this operation is irreversible.
Raises:
ConnectionError: If unable to connect to the vector database backend.
PermissionError: If the operation is not allowed by the backend.
Example:
>>> from crewai.rag.chromadb.client import ChromaDBClient
>>> client = ChromaDBClient()
>>> client.reset()
>>> print("Vector database completely reset - all data deleted")
"""
...
@abstractmethod
async def areset(self) -> None:
"""Reset the vector database by deleting all collections and data asynchronously.
Raises:
ConnectionError: If unable to connect to the vector database backend.
PermissionError: If the operation is not allowed by the backend.
Example:
>>> import asyncio
>>> from crewai.rag.chromadb.client import ChromaDBClient
>>>
>>> async def reset_database():
... client = ChromaDBClient()
... await client.areset()
... print("Vector database completely reset - all data deleted")
...
>>> asyncio.run(reset_database())
"""
...

View File

@@ -0,0 +1,26 @@
"""Core exceptions for RAG module."""
class ClientMethodMismatchError(TypeError):
"""Raised when a method is called with the wrong client type.
Typically used when a sync method is called with an async client,
or vice versa.
"""
def __init__(
self, method_name: str, expected_client: str, alt_method: str, alt_client: str
) -> None:
"""Create a ClientMethodMismatchError.
Args:
method_name: Method that was called incorrectly.
expected_client: Required client type.
alt_method: Suggested alternative method.
alt_client: Client type for the alternative method.
"""
message = (
f"Method {method_name}() requires a {expected_client}. "
f"Use {alt_method}() for {alt_client}."
)
super().__init__(message)

View File

@@ -0,0 +1 @@
"""Elasticsearch RAG implementation."""

View File

@@ -0,0 +1,502 @@
"""Elasticsearch client implementation."""
from typing import Any, cast
from typing_extensions import Unpack
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
BaseCollectionSearchParams,
)
from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.elasticsearch.types import (
AsyncEmbeddingFunction,
EmbeddingFunction,
ElasticsearchClientType,
ElasticsearchCollectionCreateParams,
)
from crewai.rag.elasticsearch.utils import (
_is_async_client,
_is_async_embedding_function,
_is_sync_client,
_prepare_document_for_elasticsearch,
_process_search_results,
_build_vector_search_query,
_get_index_mapping,
)
from crewai.rag.types import SearchResult
class ElasticsearchClient(BaseClient):
"""Elasticsearch implementation of the BaseClient protocol.
Provides vector database operations for Elasticsearch, supporting both
synchronous and asynchronous clients.
Attributes:
client: Elasticsearch client instance (Elasticsearch or AsyncElasticsearch).
embedding_function: Function to generate embeddings for documents.
vector_dimension: Dimension of the embedding vectors.
similarity: Similarity function to use for vector search.
"""
def __init__(
self,
client: ElasticsearchClientType,
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
vector_dimension: int = 384,
similarity: str = "cosine",
) -> None:
"""Initialize ElasticsearchClient with client and embedding function.
Args:
client: Pre-configured Elasticsearch client instance.
embedding_function: Embedding function for text to vector conversion.
vector_dimension: Dimension of the embedding vectors.
similarity: Similarity function to use for vector search.
"""
self.client = client
self.embedding_function = embedding_function
self.vector_dimension = vector_dimension
self.similarity = similarity
def create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> None:
"""Create a new index in Elasticsearch.
Keyword Args:
collection_name: Name of the index to create. Must be unique.
index_settings: Optional index settings.
vector_dimension: Optional vector dimension override.
similarity: Optional similarity function override.
Raises:
ValueError: If index with the same name already exists.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="create_collection",
expected_client="Elasticsearch",
alt_method="acreate_collection",
alt_client="AsyncElasticsearch",
)
collection_name = kwargs["collection_name"]
if self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' already exists")
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
similarity = kwargs.get("similarity", self.similarity)
mapping = _get_index_mapping(vector_dim, similarity)
index_settings = kwargs.get("index_settings", {})
if index_settings:
mapping["settings"] = index_settings
self.client.indices.create(index=collection_name, body=mapping)
async def acreate_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> None:
"""Create a new index in Elasticsearch asynchronously.
Keyword Args:
collection_name: Name of the index to create. Must be unique.
index_settings: Optional index settings.
vector_dimension: Optional vector dimension override.
similarity: Optional similarity function override.
Raises:
ValueError: If index with the same name already exists.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="acreate_collection",
expected_client="AsyncElasticsearch",
alt_method="create_collection",
alt_client="Elasticsearch",
)
collection_name = kwargs["collection_name"]
if await self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' already exists")
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
similarity = kwargs.get("similarity", self.similarity)
mapping = _get_index_mapping(vector_dim, similarity)
index_settings = kwargs.get("index_settings", {})
if index_settings:
mapping["settings"] = index_settings
await self.client.indices.create(index=collection_name, body=mapping)
def get_or_create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> Any:
"""Get an existing index or create it if it doesn't exist.
Keyword Args:
collection_name: Name of the index to get or create.
index_settings: Optional index settings.
vector_dimension: Optional vector dimension override.
similarity: Optional similarity function override.
Returns:
Index info dict with name and other metadata.
Raises:
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="get_or_create_collection",
expected_client="Elasticsearch",
alt_method="aget_or_create_collection",
alt_client="AsyncElasticsearch",
)
collection_name = kwargs["collection_name"]
if self.client.indices.exists(index=collection_name):
return self.client.indices.get(index=collection_name)
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
similarity = kwargs.get("similarity", self.similarity)
mapping = _get_index_mapping(vector_dim, similarity)
index_settings = kwargs.get("index_settings", {})
if index_settings:
mapping["settings"] = index_settings
self.client.indices.create(index=collection_name, body=mapping)
return self.client.indices.get(index=collection_name)
async def aget_or_create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> Any:
"""Get an existing index or create it if it doesn't exist asynchronously.
Keyword Args:
collection_name: Name of the index to get or create.
index_settings: Optional index settings.
vector_dimension: Optional vector dimension override.
similarity: Optional similarity function override.
Returns:
Index info dict with name and other metadata.
Raises:
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="aget_or_create_collection",
expected_client="AsyncElasticsearch",
alt_method="get_or_create_collection",
alt_client="Elasticsearch",
)
collection_name = kwargs["collection_name"]
if await self.client.indices.exists(index=collection_name):
return await self.client.indices.get(index=collection_name)
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
similarity = kwargs.get("similarity", self.similarity)
mapping = _get_index_mapping(vector_dim, similarity)
index_settings = kwargs.get("index_settings", {})
if index_settings:
mapping["settings"] = index_settings
await self.client.indices.create(index=collection_name, body=mapping)
return await self.client.indices.get(index=collection_name)
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to an index.
Keyword Args:
collection_name: The name of the index to add documents to.
documents: List of BaseRecord dicts containing document data.
Raises:
ValueError: If index doesn't exist or documents list is empty.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="add_documents",
expected_client="Elasticsearch",
alt_method="aadd_documents",
alt_client="AsyncElasticsearch",
)
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
if not documents:
raise ValueError("Documents list cannot be empty")
if not self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
for doc in documents:
if _is_async_embedding_function(self.embedding_function):
raise TypeError(
"Async embedding function cannot be used with sync add_documents. "
"Use aadd_documents instead."
)
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
prepared_doc = _prepare_document_for_elasticsearch(doc, embedding)
self.client.index(
index=collection_name,
id=prepared_doc["id"],
body=prepared_doc["body"]
)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to an index asynchronously.
Keyword Args:
collection_name: The name of the index to add documents to.
documents: List of BaseRecord dicts containing document data.
Raises:
ValueError: If index doesn't exist or documents list is empty.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="aadd_documents",
expected_client="AsyncElasticsearch",
alt_method="add_documents",
alt_client="Elasticsearch",
)
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
if not documents:
raise ValueError("Documents list cannot be empty")
if not await self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
for doc in documents:
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
embedding = await async_fn(doc["content"])
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
prepared_doc = _prepare_document_for_elasticsearch(doc, embedding)
await self.client.index(
index=collection_name,
id=prepared_doc["id"],
body=prepared_doc["body"]
)
def search(
self, **kwargs: Unpack[BaseCollectionSearchParams]
) -> list[SearchResult]:
"""Search for similar documents using a query.
Keyword Args:
collection_name: Name of the index to search in.
query: The text query to search for.
limit: Maximum number of results to return (default: 10).
metadata_filter: Optional filter for metadata fields.
score_threshold: Optional minimum similarity score (0-1) for results.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
Raises:
ValueError: If index doesn't exist.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="search",
expected_client="Elasticsearch",
alt_method="asearch",
alt_client="AsyncElasticsearch",
)
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
if not self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
if _is_async_embedding_function(self.embedding_function):
raise TypeError(
"Async embedding function cannot be used with sync search. "
"Use asearch instead."
)
sync_fn = cast(EmbeddingFunction, self.embedding_function)
query_embedding = sync_fn(query)
search_query = _build_vector_search_query(
query_vector=query_embedding,
limit=limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)
response = self.client.search(index=collection_name, body=search_query)
return _process_search_results(response, score_threshold)
async def asearch(
self, **kwargs: Unpack[BaseCollectionSearchParams]
) -> list[SearchResult]:
"""Search for similar documents using a query asynchronously.
Keyword Args:
collection_name: Name of the index to search in.
query: The text query to search for.
limit: Maximum number of results to return (default: 10).
metadata_filter: Optional filter for metadata fields.
score_threshold: Optional minimum similarity score (0-1) for results.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
Raises:
ValueError: If index doesn't exist.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="asearch",
expected_client="AsyncElasticsearch",
alt_method="search",
alt_client="Elasticsearch",
)
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
if not await self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
query_embedding = await async_fn(query)
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
query_embedding = sync_fn(query)
search_query = _build_vector_search_query(
query_vector=query_embedding,
limit=limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)
response = await self.client.search(index=collection_name, body=search_query)
return _process_search_results(response, score_threshold)
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete an index and all its data.
Keyword Args:
collection_name: Name of the index to delete.
Raises:
ValueError: If index doesn't exist.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="delete_collection",
expected_client="Elasticsearch",
alt_method="adelete_collection",
alt_client="AsyncElasticsearch",
)
collection_name = kwargs["collection_name"]
if not self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
self.client.indices.delete(index=collection_name)
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete an index and all its data asynchronously.
Keyword Args:
collection_name: Name of the index to delete.
Raises:
ValueError: If index doesn't exist.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="adelete_collection",
expected_client="AsyncElasticsearch",
alt_method="delete_collection",
alt_client="Elasticsearch",
)
collection_name = kwargs["collection_name"]
if not await self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
await self.client.indices.delete(index=collection_name)
def reset(self) -> None:
"""Reset the vector database by deleting all indices and data.
Raises:
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="reset",
expected_client="Elasticsearch",
alt_method="areset",
alt_client="AsyncElasticsearch",
)
indices_response = self.client.indices.get(index="*")
for index_name in indices_response.keys():
if not index_name.startswith("."):
self.client.indices.delete(index=index_name)
async def areset(self) -> None:
"""Reset the vector database by deleting all indices and data asynchronously.
Raises:
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="areset",
expected_client="AsyncElasticsearch",
alt_method="reset",
alt_client="Elasticsearch",
)
indices_response = await self.client.indices.get(index="*")
for index_name in indices_response.keys():
if not index_name.startswith("."):
await self.client.indices.delete(index=index_name)

View File

@@ -0,0 +1,92 @@
"""Elasticsearch configuration model."""
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.elasticsearch.types import (
ElasticsearchClientParams,
ElasticsearchEmbeddingFunctionWrapper,
)
from crewai.rag.elasticsearch.constants import (
DEFAULT_HOST,
DEFAULT_PORT,
DEFAULT_EMBEDDING_MODEL,
DEFAULT_VECTOR_DIMENSION,
)
def _default_options() -> ElasticsearchClientParams:
"""Create default Elasticsearch client options.
Returns:
Default options with local Elasticsearch connection.
"""
return ElasticsearchClientParams(
hosts=[f"http://{DEFAULT_HOST}:{DEFAULT_PORT}"],
use_ssl=False,
verify_certs=False,
timeout=30,
)
def _default_embedding_function() -> ElasticsearchEmbeddingFunctionWrapper:
"""Create default Elasticsearch embedding function.
Returns:
Default embedding function using sentence-transformers.
"""
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(DEFAULT_EMBEDDING_MODEL)
def embed_fn(text: str) -> list[float]:
"""Embed a single text string.
Args:
text: Text to embed.
Returns:
Embedding vector as list of floats.
"""
embedding = model.encode(text, convert_to_tensor=False)
return embedding.tolist() if hasattr(embedding, 'tolist') else list(embedding)
return cast(ElasticsearchEmbeddingFunctionWrapper, embed_fn)
except ImportError:
def fallback_embed_fn(text: str) -> list[float]:
"""Fallback embedding function when sentence-transformers is not available."""
import hashlib
import struct
hash_obj = hashlib.md5(text.encode(), usedforsecurity=False)
hash_bytes = hash_obj.digest()
vector = []
for i in range(0, len(hash_bytes), 4):
chunk = hash_bytes[i:i+4]
if len(chunk) == 4:
value = struct.unpack('f', chunk)[0]
vector.append(float(value))
while len(vector) < DEFAULT_VECTOR_DIMENSION:
vector.extend(vector[:DEFAULT_VECTOR_DIMENSION - len(vector)])
return vector[:DEFAULT_VECTOR_DIMENSION]
return cast(ElasticsearchEmbeddingFunctionWrapper, fallback_embed_fn)
@pyd_dataclass(frozen=True)
class ElasticsearchConfig(BaseRagConfig):
"""Configuration for Elasticsearch client."""
provider: Literal["elasticsearch"] = field(default="elasticsearch", init=False)
options: ElasticsearchClientParams = field(default_factory=_default_options)
vector_dimension: int = DEFAULT_VECTOR_DIMENSION
similarity: str = "cosine"
embedding_function: ElasticsearchEmbeddingFunctionWrapper = field(
default_factory=_default_embedding_function
)

View File

@@ -0,0 +1,12 @@
"""Constants for Elasticsearch RAG implementation."""
from typing import Final
DEFAULT_HOST: Final[str] = "localhost"
DEFAULT_PORT: Final[int] = 9200
DEFAULT_INDEX_SETTINGS: Final[dict] = {
"number_of_shards": 1,
"number_of_replicas": 0,
}
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_VECTOR_DIMENSION: Final[int] = 384

View File

@@ -0,0 +1,31 @@
"""Factory functions for creating Elasticsearch clients."""
from crewai.rag.elasticsearch.config import ElasticsearchConfig
from crewai.rag.elasticsearch.client import ElasticsearchClient
def create_client(config: ElasticsearchConfig) -> ElasticsearchClient:
"""Create an ElasticsearchClient from configuration.
Args:
config: Elasticsearch configuration object.
Returns:
Configured ElasticsearchClient instance.
"""
try:
from elasticsearch import Elasticsearch
except ImportError as e:
raise ImportError(
"elasticsearch package is required for Elasticsearch support. "
"Install it with: pip install elasticsearch"
) from e
client = Elasticsearch(**config.options)
return ElasticsearchClient(
client=client,
embedding_function=config.embedding_function,
vector_dimension=config.vector_dimension,
similarity=config.similarity,
)

View File

@@ -0,0 +1,93 @@
"""Type definitions for Elasticsearch RAG implementation."""
from typing import Any, Protocol, Union, TYPE_CHECKING
from typing_extensions import NotRequired, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
if TYPE_CHECKING:
from typing import TypeAlias
from elasticsearch import Elasticsearch, AsyncElasticsearch
ElasticsearchClientType: TypeAlias = Union[Elasticsearch, AsyncElasticsearch]
else:
try:
from elasticsearch import Elasticsearch, AsyncElasticsearch
ElasticsearchClientType = Union[Elasticsearch, AsyncElasticsearch]
except ImportError:
ElasticsearchClientType = Any
class ElasticsearchClientParams(TypedDict, total=False):
"""Parameters for Elasticsearch client initialization."""
hosts: NotRequired[list[str]]
cloud_id: NotRequired[str]
username: NotRequired[str]
password: NotRequired[str]
api_key: NotRequired[str]
use_ssl: NotRequired[bool]
verify_certs: NotRequired[bool]
ca_certs: NotRequired[str]
timeout: NotRequired[int]
class ElasticsearchIndexSettings(TypedDict, total=False):
"""Settings for Elasticsearch index creation."""
number_of_shards: NotRequired[int]
number_of_replicas: NotRequired[int]
refresh_interval: NotRequired[str]
class ElasticsearchCollectionCreateParams(TypedDict, total=False):
"""Parameters for creating Elasticsearch collections/indices."""
collection_name: str
index_settings: NotRequired[ElasticsearchIndexSettings]
vector_dimension: NotRequired[int]
similarity: NotRequired[str]
class EmbeddingFunction(Protocol):
"""Protocol for embedding functions that convert text to vectors."""
def __call__(self, text: str) -> list[float]:
"""Convert text to embedding vector.
Args:
text: Input text to embed.
Returns:
Embedding vector as list of floats.
"""
...
class AsyncEmbeddingFunction(Protocol):
"""Protocol for async embedding functions that convert text to vectors."""
async def __call__(self, text: str) -> list[float]:
"""Convert text to embedding vector asynchronously.
Args:
text: Input text to embed.
Returns:
Embedding vector as list of floats.
"""
...
class ElasticsearchEmbeddingFunctionWrapper(EmbeddingFunction):
"""Base class for Elasticsearch EmbeddingFunction to work with Pydantic validation."""
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for Elasticsearch EmbeddingFunction.
This allows Pydantic to handle Elasticsearch's EmbeddingFunction type
without requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()

View File

@@ -0,0 +1,186 @@
"""Utility functions for Elasticsearch RAG implementation."""
import hashlib
from typing import Any, TypeGuard
from crewai.rag.elasticsearch.types import (
AsyncEmbeddingFunction,
EmbeddingFunction,
ElasticsearchClientType,
)
from crewai.rag.types import BaseRecord, SearchResult
try:
from elasticsearch import Elasticsearch, AsyncElasticsearch
except ImportError:
Elasticsearch = None
AsyncElasticsearch = None
def _is_sync_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
"""Type guard to check if the client is a sync Elasticsearch client."""
if Elasticsearch is None:
return False
return isinstance(client, Elasticsearch)
def _is_async_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
"""Type guard to check if the client is an async Elasticsearch client."""
if AsyncElasticsearch is None:
return False
return isinstance(client, AsyncElasticsearch)
def _is_async_embedding_function(
func: EmbeddingFunction | AsyncEmbeddingFunction,
) -> TypeGuard[AsyncEmbeddingFunction]:
"""Type guard to check if the embedding function is async."""
import inspect
return inspect.iscoroutinefunction(func)
def _generate_doc_id(content: str) -> str:
"""Generate a document ID from content using SHA256 hash."""
return hashlib.sha256(content.encode()).hexdigest()
def _prepare_document_for_elasticsearch(
doc: BaseRecord, embedding: list[float]
) -> dict[str, Any]:
"""Prepare a document for Elasticsearch indexing.
Args:
doc: Document record to prepare.
embedding: Embedding vector for the document.
Returns:
Document formatted for Elasticsearch.
"""
doc_id = doc.get("doc_id") or _generate_doc_id(doc["content"])
es_doc = {
"content": doc["content"],
"content_vector": embedding,
"metadata": doc.get("metadata", {}),
}
return {"id": doc_id, "body": es_doc}
def _process_search_results(
response: dict[str, Any], score_threshold: float | None = None
) -> list[SearchResult]:
"""Process Elasticsearch search response into SearchResult format.
Args:
response: Raw Elasticsearch search response.
score_threshold: Optional minimum score threshold.
Returns:
List of SearchResult dictionaries.
"""
results = []
hits = response.get("hits", {}).get("hits", [])
for hit in hits:
score = hit.get("_score", 0.0)
if score_threshold is not None and score < score_threshold:
continue
source = hit.get("_source", {})
result = SearchResult(
id=hit.get("_id", ""),
content=source.get("content", ""),
metadata=source.get("metadata", {}),
score=score,
)
results.append(result)
return results
def _build_vector_search_query(
query_vector: list[float],
limit: int = 10,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float | None = None,
) -> dict[str, Any]:
"""Build Elasticsearch query for vector similarity search.
Args:
query_vector: Query embedding vector.
limit: Maximum number of results.
metadata_filter: Optional metadata filter.
score_threshold: Optional minimum score threshold.
Returns:
Elasticsearch query dictionary.
"""
query = {
"size": limit,
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'content_vector') + 1.0",
"params": {"query_vector": query_vector}
}
}
}
}
if metadata_filter:
bool_query = {
"bool": {
"must": [
query["query"]
],
"filter": []
}
}
for key, value in metadata_filter.items():
bool_query["bool"]["filter"].append({
"term": {f"metadata.{key}": value}
})
query["query"] = bool_query
if score_threshold is not None:
query["min_score"] = score_threshold
return query
def _get_index_mapping(vector_dimension: int, similarity: str = "cosine") -> dict[str, Any]:
"""Get Elasticsearch index mapping for vector search.
Args:
vector_dimension: Dimension of the embedding vectors.
similarity: Similarity function to use.
Returns:
Elasticsearch mapping dictionary.
"""
return {
"mappings": {
"properties": {
"content": {
"type": "text",
"analyzer": "standard"
},
"content_vector": {
"type": "dense_vector",
"dims": vector_dimension,
"similarity": similarity
},
"metadata": {
"type": "object",
"dynamic": True
}
}
}
}

View File

@@ -0,0 +1,148 @@
"""Minimal embedding function factory for CrewAI."""
import os
from chromadb import EmbeddingFunction
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GooglePalmEmbeddingFunction,
GoogleGenerativeAiEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
)
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from crewai.rag.embeddings.types import EmbeddingOptions
def get_embedding_function(
config: EmbeddingOptions | dict | None = None,
) -> EmbeddingFunction:
"""Get embedding function - delegates to ChromaDB.
Args:
config: Optional configuration - either an EmbeddingOptions object or a dict with:
- provider: The embedding provider to use (default: "openai")
- Any other provider-specific parameters
Returns:
EmbeddingFunction instance ready for use with ChromaDB
Supported providers:
- openai: OpenAI embeddings (default)
- cohere: Cohere embeddings
- ollama: Ollama local embeddings
- huggingface: HuggingFace embeddings
- sentence-transformer: Local sentence transformers
- instructor: Instructor embeddings for specialized tasks
- google-palm: Google PaLM embeddings
- google-generativeai: Google Generative AI embeddings
- google-vertex: Google Vertex AI embeddings
- amazon-bedrock: AWS Bedrock embeddings
- jina: Jina AI embeddings
- roboflow: Roboflow embeddings for vision tasks
- openclip: OpenCLIP embeddings for multimodal tasks
- text2vec: Text2Vec embeddings
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
Examples:
# Use default OpenAI with retry logic
>>> embedder = get_embedding_function()
# Use Cohere with dict
>>> embedder = get_embedding_function({
... "provider": "cohere",
... "api_key": "your-key",
... "model_name": "embed-english-v3.0"
... })
# Use with EmbeddingOptions
>>> embedder = get_embedding_function(
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
... )
# Use local sentence transformers (no API key needed)
>>> embedder = get_embedding_function({
... "provider": "sentence-transformer",
... "model_name": "all-MiniLM-L6-v2"
... })
# Use Ollama for local embeddings
>>> embedder = get_embedding_function({
... "provider": "ollama",
... "model_name": "nomic-embed-text"
... })
# Use ONNX (no API key needed)
>>> embedder = get_embedding_function({
... "provider": "onnx"
... })
"""
if config is None:
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
# Handle EmbeddingOptions object
if isinstance(config, EmbeddingOptions):
config_dict = config.model_dump(exclude_none=True)
else:
config_dict = config.copy()
provider = config_dict.pop("provider", "openai")
embedding_functions = {
"openai": OpenAIEmbeddingFunction,
"cohere": CohereEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
"huggingface": HuggingFaceEmbeddingFunction,
"sentence-transformer": SentenceTransformerEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"google-palm": GooglePalmEmbeddingFunction,
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
"google-vertex": GoogleVertexEmbeddingFunction,
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"roboflow": RoboflowEmbeddingFunction,
"openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2,
}
if provider not in embedding_functions:
raise ValueError(
f"Unsupported provider: {provider}. "
f"Available providers: {list(embedding_functions.keys())}"
)
return embedding_functions[provider](**config_dict)

View File

@@ -0,0 +1,62 @@
"""Type definitions for the embeddings module."""
from typing import Literal
from pydantic import BaseModel, Field, SecretStr
from crewai.rag.types import EmbeddingFunction
EmbeddingProvider = Literal[
"openai",
"cohere",
"ollama",
"huggingface",
"sentence-transformer",
"instructor",
"google-palm",
"google-generativeai",
"google-vertex",
"amazon-bedrock",
"jina",
"roboflow",
"openclip",
"text2vec",
"onnx",
]
"""Supported embedding providers.
These correspond to the embedding functions available in ChromaDB's
embedding_functions module. Each provider has specific requirements
and configuration options.
"""
class EmbeddingOptions(BaseModel):
"""Configuration options for embedding providers.
Generic attributes that can be passed to get_embedding_function
to configure various embedding providers.
"""
provider: EmbeddingProvider = Field(
..., description="Embedding provider name (e.g., 'openai', 'cohere', 'onnx')"
)
model_name: str | None = Field(
default=None, description="Model name for the embedding provider"
)
api_key: SecretStr | None = Field(
default=None, description="API key for the embedding provider"
)
class EmbeddingConfig(BaseModel):
"""Configuration wrapper for embedding functions.
Accepts either a pre-configured EmbeddingFunction or EmbeddingOptions
to create one. This provides flexibility in how embeddings are configured.
Attributes:
function: Either a callable EmbeddingFunction or EmbeddingOptions to create one
"""
function: EmbeddingFunction | EmbeddingOptions

58
src/crewai/rag/factory.py Normal file
View File

@@ -0,0 +1,58 @@
"""Factory functions for creating RAG clients from configuration."""
from typing import cast
from crewai.rag.config.optional_imports.protocols import (
ChromaFactoryModule,
QdrantFactoryModule,
ElasticsearchFactoryModule,
)
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
from crewai.utilities.import_utils import require
def create_client(config: RagConfigType) -> BaseClient:
"""Create a client from configuration using the appropriate factory.
Args:
config: The RAG client configuration.
Returns:
The created client instance.
Raises:
ValueError: If the configuration provider is not supported.
"""
if config.provider == "chromadb":
chromadb_mod = cast(
ChromaFactoryModule,
require(
"crewai.rag.chromadb.factory",
purpose="The 'chromadb' provider",
),
)
return chromadb_mod.create_client(config)
if config.provider == "qdrant":
qdrant_mod = cast(
QdrantFactoryModule,
require(
"crewai.rag.qdrant.factory",
purpose="The 'qdrant' provider",
),
)
return qdrant_mod.create_client(config)
if config.provider == "elasticsearch":
elasticsearch_mod = cast(
ElasticsearchFactoryModule,
require(
"crewai.rag.elasticsearch.factory",
purpose="The 'elasticsearch' provider",
),
)
return elasticsearch_mod.create_client(config)
raise ValueError(f"Unsupported provider: {config.provider}")

View File

@@ -0,0 +1 @@
"""Qdrant vector database client implementation."""

View File

@@ -0,0 +1,501 @@
"""Qdrant client implementation."""
from typing import Any, cast
from typing_extensions import Unpack
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
BaseCollectionSearchParams,
)
from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.qdrant.types import (
AsyncEmbeddingFunction,
EmbeddingFunction,
QdrantClientType,
QdrantCollectionCreateParams,
)
from crewai.rag.qdrant.utils import (
_is_async_client,
_is_async_embedding_function,
_is_sync_client,
_create_point_from_document,
_get_collection_params,
_prepare_search_params,
_process_search_results,
)
from crewai.rag.types import SearchResult
class QdrantClient(BaseClient):
"""Qdrant implementation of the BaseClient protocol.
Provides vector database operations for Qdrant, supporting both
synchronous and asynchronous clients.
Attributes:
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
embedding_function: Function to generate embeddings for documents.
"""
def __init__(
self,
client: QdrantClientType,
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
) -> None:
"""Initialize QdrantClient with client and embedding function.
Args:
client: Pre-configured Qdrant client instance.
embedding_function: Embedding function for text to vector conversion.
"""
self.client = client
self.embedding_function = embedding_function
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
"""Create a new collection in Qdrant.
Keyword Args:
collection_name: Name of the collection to create. Must be unique.
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
sparse_vectors_config: Optional sparse vector configuration.
shard_number: Optional number of shards.
replication_factor: Optional replication factor.
write_consistency_factor: Optional write consistency factor.
on_disk_payload: Optional flag to store payload on disk.
hnsw_config: Optional HNSW index configuration.
optimizers_config: Optional optimizer configuration.
wal_config: Optional write-ahead log configuration.
quantization_config: Optional quantization configuration.
init_from: Optional collection to initialize from.
timeout: Optional timeout for the operation.
Raises:
ValueError: If collection with the same name already exists.
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="create_collection",
expected_client="QdrantClient",
alt_method="acreate_collection",
alt_client="AsyncQdrantClient",
)
collection_name = kwargs["collection_name"]
if self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' already exists")
params = _get_collection_params(kwargs)
self.client.create_collection(**params)
async def acreate_collection(
self, **kwargs: Unpack[QdrantCollectionCreateParams]
) -> None:
"""Create a new collection in Qdrant asynchronously.
Keyword Args:
collection_name: Name of the collection to create. Must be unique.
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
sparse_vectors_config: Optional sparse vector configuration.
shard_number: Optional number of shards.
replication_factor: Optional replication factor.
write_consistency_factor: Optional write consistency factor.
on_disk_payload: Optional flag to store payload on disk.
hnsw_config: Optional HNSW index configuration.
optimizers_config: Optional optimizer configuration.
wal_config: Optional write-ahead log configuration.
quantization_config: Optional quantization configuration.
init_from: Optional collection to initialize from.
timeout: Optional timeout for the operation.
Raises:
ValueError: If collection with the same name already exists.
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="acreate_collection",
expected_client="AsyncQdrantClient",
alt_method="create_collection",
alt_client="QdrantClient",
)
collection_name = kwargs["collection_name"]
if await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' already exists")
params = _get_collection_params(kwargs)
await self.client.create_collection(**params)
def get_or_create_collection(
self, **kwargs: Unpack[QdrantCollectionCreateParams]
) -> Any:
"""Get an existing collection or create it if it doesn't exist.
Keyword Args:
collection_name: Name of the collection to get or create.
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
sparse_vectors_config: Optional sparse vector configuration.
shard_number: Optional number of shards.
replication_factor: Optional replication factor.
write_consistency_factor: Optional write consistency factor.
on_disk_payload: Optional flag to store payload on disk.
hnsw_config: Optional HNSW index configuration.
optimizers_config: Optional optimizer configuration.
wal_config: Optional write-ahead log configuration.
quantization_config: Optional quantization configuration.
init_from: Optional collection to initialize from.
timeout: Optional timeout for the operation.
Returns:
Collection info dict with name and other metadata.
Raises:
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="get_or_create_collection",
expected_client="QdrantClient",
alt_method="aget_or_create_collection",
alt_client="AsyncQdrantClient",
)
collection_name = kwargs["collection_name"]
if self.client.collection_exists(collection_name):
return self.client.get_collection(collection_name)
params = _get_collection_params(kwargs)
self.client.create_collection(**params)
return self.client.get_collection(collection_name)
async def aget_or_create_collection(
self, **kwargs: Unpack[QdrantCollectionCreateParams]
) -> Any:
"""Get an existing collection or create it if it doesn't exist asynchronously.
Keyword Args:
collection_name: Name of the collection to get or create.
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
sparse_vectors_config: Optional sparse vector configuration.
shard_number: Optional number of shards.
replication_factor: Optional replication factor.
write_consistency_factor: Optional write consistency factor.
on_disk_payload: Optional flag to store payload on disk.
hnsw_config: Optional HNSW index configuration.
optimizers_config: Optional optimizer configuration.
wal_config: Optional write-ahead log configuration.
quantization_config: Optional quantization configuration.
init_from: Optional collection to initialize from.
timeout: Optional timeout for the operation.
Returns:
Collection info dict with name and other metadata.
Raises:
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="aget_or_create_collection",
expected_client="AsyncQdrantClient",
alt_method="get_or_create_collection",
alt_client="QdrantClient",
)
collection_name = kwargs["collection_name"]
if await self.client.collection_exists(collection_name):
return await self.client.get_collection(collection_name)
params = _get_collection_params(kwargs)
await self.client.create_collection(**params)
return await self.client.get_collection(collection_name)
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection.
Keyword Args:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing document data.
Raises:
ValueError: If collection doesn't exist or documents list is empty.
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="add_documents",
expected_client="QdrantClient",
alt_method="aadd_documents",
alt_client="AsyncQdrantClient",
)
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
if not documents:
raise ValueError("Documents list cannot be empty")
if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
points = []
for doc in documents:
if _is_async_embedding_function(self.embedding_function):
raise TypeError(
"Async embedding function cannot be used with sync add_documents. "
"Use aadd_documents instead."
)
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding)
points.append(point)
self.client.upsert(collection_name=collection_name, points=points)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection asynchronously.
Keyword Args:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing document data.
Raises:
ValueError: If collection doesn't exist or documents list is empty.
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="aadd_documents",
expected_client="AsyncQdrantClient",
alt_method="add_documents",
alt_client="QdrantClient",
)
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
if not documents:
raise ValueError("Documents list cannot be empty")
if not await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
points = []
for doc in documents:
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
embedding = await async_fn(doc["content"])
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding)
points.append(point)
await self.client.upsert(collection_name=collection_name, points=points)
def search(
self, **kwargs: Unpack[BaseCollectionSearchParams]
) -> list[SearchResult]:
"""Search for similar documents using a query.
Keyword Args:
collection_name: Name of the collection to search in.
query: The text query to search for.
limit: Maximum number of results to return (default: 10).
metadata_filter: Optional filter for metadata fields.
score_threshold: Optional minimum similarity score (0-1) for results.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
Raises:
ValueError: If collection doesn't exist.
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="search",
expected_client="QdrantClient",
alt_method="asearch",
alt_client="AsyncQdrantClient",
)
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
if _is_async_embedding_function(self.embedding_function):
raise TypeError(
"Async embedding function cannot be used with sync search. "
"Use asearch instead."
)
sync_fn = cast(EmbeddingFunction, self.embedding_function)
query_embedding = sync_fn(query)
search_kwargs = _prepare_search_params(
collection_name=collection_name,
query_embedding=query_embedding,
limit=limit,
score_threshold=score_threshold,
metadata_filter=metadata_filter,
)
response = self.client.query_points(**search_kwargs)
return _process_search_results(response)
async def asearch(
self, **kwargs: Unpack[BaseCollectionSearchParams]
) -> list[SearchResult]:
"""Search for similar documents using a query asynchronously.
Keyword Args:
collection_name: Name of the collection to search in.
query: The text query to search for.
limit: Maximum number of results to return (default: 10).
metadata_filter: Optional filter for metadata fields.
score_threshold: Optional minimum similarity score (0-1) for results.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
Raises:
ValueError: If collection doesn't exist.
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="asearch",
expected_client="AsyncQdrantClient",
alt_method="search",
alt_client="QdrantClient",
)
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
if not await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
query_embedding = await async_fn(query)
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
query_embedding = sync_fn(query)
search_kwargs = _prepare_search_params(
collection_name=collection_name,
query_embedding=query_embedding,
limit=limit,
score_threshold=score_threshold,
metadata_filter=metadata_filter,
)
response = await self.client.query_points(**search_kwargs)
return _process_search_results(response)
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data.
Keyword Args:
collection_name: Name of the collection to delete.
Raises:
ValueError: If collection doesn't exist.
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="delete_collection",
expected_client="QdrantClient",
alt_method="adelete_collection",
alt_client="AsyncQdrantClient",
)
collection_name = kwargs["collection_name"]
if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
self.client.delete_collection(collection_name=collection_name)
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data asynchronously.
Keyword Args:
collection_name: Name of the collection to delete.
Raises:
ValueError: If collection doesn't exist.
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="adelete_collection",
expected_client="AsyncQdrantClient",
alt_method="delete_collection",
alt_client="QdrantClient",
)
collection_name = kwargs["collection_name"]
if not await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
await self.client.delete_collection(collection_name=collection_name)
def reset(self) -> None:
"""Reset the vector database by deleting all collections and data.
Raises:
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="reset",
expected_client="QdrantClient",
alt_method="areset",
alt_client="AsyncQdrantClient",
)
collections_response = self.client.get_collections()
for collection in collections_response.collections:
self.client.delete_collection(collection_name=collection.name)
async def areset(self) -> None:
"""Reset the vector database by deleting all collections and data asynchronously.
Raises:
ConnectionError: If unable to connect to Qdrant server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="areset",
expected_client="AsyncQdrantClient",
alt_method="reset",
alt_client="QdrantClient",
)
collections_response = await self.client.get_collections()
for collection in collections_response.collections:
await self.client.delete_collection(collection_name=collection.name)

View File

@@ -0,0 +1,54 @@
"""Qdrant configuration model."""
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
def _default_options() -> QdrantClientParams:
"""Create default Qdrant client options.
Returns:
Default options with file-based storage.
"""
return QdrantClientParams(path=DEFAULT_STORAGE_PATH)
def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
"""Create default Qdrant embedding function.
Returns:
Default embedding function using fastembed with all-MiniLM-L6-v2.
"""
from fastembed import TextEmbedding
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
def embed_fn(text: str) -> list[float]:
"""Embed a single text string.
Args:
text: Text to embed.
Returns:
Embedding vector as list of floats.
"""
embeddings = list(model.embed([text]))
return embeddings[0].tolist() if embeddings else []
return cast(QdrantEmbeddingFunctionWrapper, embed_fn)
@pyd_dataclass(frozen=True)
class QdrantConfig(BaseRagConfig):
"""Configuration for Qdrant client."""
provider: Literal["qdrant"] = field(default="qdrant", init=False)
options: QdrantClientParams = field(default_factory=_default_options)
embedding_function: QdrantEmbeddingFunctionWrapper = field(
default_factory=_default_embedding_function
)

View File

@@ -0,0 +1,12 @@
"""Constants for Qdrant implementation."""
import os
from typing import Final
from qdrant_client.models import Distance, VectorParams
from crewai.utilities.paths import db_storage_path
DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE)
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "qdrant")

View File

@@ -0,0 +1,21 @@
"""Factory functions for creating Qdrant clients from configuration."""
from qdrant_client import QdrantClient as SyncQdrantClientBase
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
def create_client(config: QdrantConfig) -> QdrantClient:
"""Create a Qdrant client from configuration.
Args:
config: The Qdrant configuration.
Returns:
A configured QdrantClient instance.
"""
qdrant_client = SyncQdrantClientBase(**config.options)
return QdrantClient(
client=qdrant_client, embedding_function=config.embedding_function
)

View File

@@ -0,0 +1,155 @@
"""Type definitions specific to Qdrant implementation."""
from collections.abc import Awaitable, Callable
from typing import Annotated, Any, Protocol, TypeAlias
from typing_extensions import NotRequired, TypedDict
import numpy as np
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client.models import (
FieldCondition,
Filter,
HasIdCondition,
HasVectorCondition,
HnswConfigDiff,
InitFrom,
IsEmptyCondition,
IsNullCondition,
NestedCondition,
OptimizersConfigDiff,
QuantizationConfig,
ShardingMethod,
SparseVectorsConfig,
VectorsConfig,
WalConfigDiff,
)
from crewai.rag.core.base_client import BaseCollectionParams
QdrantClientType = SyncQdrantClient | AsyncQdrantClient
QueryEmbedding: TypeAlias = list[float] | np.ndarray[Any, np.dtype[np.floating[Any]]]
BasicConditions = FieldCondition | IsEmptyCondition | IsNullCondition
StructuralConditions = HasIdCondition | HasVectorCondition | NestedCondition
FilterCondition = BasicConditions | StructuralConditions | Filter
MetadataFilterValue = bool | int | str
MetadataFilter = dict[str, MetadataFilterValue]
class EmbeddingFunction(Protocol):
"""Protocol for embedding functions that convert text to vectors."""
def __call__(self, text: str) -> QueryEmbedding:
"""Convert text to embedding vector.
Args:
text: Input text to embed.
Returns:
Embedding vector as list of floats or numpy array.
"""
...
class QdrantEmbeddingFunctionWrapper(EmbeddingFunction):
"""Base class for Qdrant EmbeddingFunction to work with Pydantic validation."""
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for Qdrant EmbeddingFunction.
This allows Pydantic to handle Qdrant's EmbeddingFunction type
without requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()
class AsyncEmbeddingFunction(Protocol):
"""Protocol for async embedding functions that convert text to vectors."""
async def __call__(self, text: str) -> QueryEmbedding:
"""Convert text to embedding vector asynchronously.
Args:
text: Input text to embed.
Returns:
Embedding vector as list of floats or numpy array.
"""
...
class QdrantClientParams(TypedDict, total=False):
"""Parameters for QdrantClient initialization.
Notes:
Need to implement in factory or remove.
"""
location: str | None
url: str | None
port: int
grpc_port: int
prefer_grpc: bool
https: bool | None
api_key: str | None
prefix: str | None
timeout: int | None
host: str | None
path: str | None
force_disable_check_same_thread: bool
grpc_options: dict[str, Any] | None
auth_token_provider: Callable[[], str] | Callable[[], Awaitable[str]] | None
cloud_inference: bool
local_inference_batch_size: int | None
check_compatibility: bool
class CommonCreateFields(TypedDict, total=False):
"""Fields shared between high-level and direct create_collection params."""
vectors_config: VectorsConfig
sparse_vectors_config: SparseVectorsConfig
shard_number: Annotated[int, "Number of shards (default: 1)"]
sharding_method: ShardingMethod
replication_factor: Annotated[int, "Number of replicas per shard (default: 1)"]
write_consistency_factor: Annotated[int, "Await N replicas on write (default: 1)"]
on_disk_payload: Annotated[bool, "Store payload on disk instead of RAM"]
hnsw_config: HnswConfigDiff
optimizers_config: OptimizersConfigDiff
wal_config: WalConfigDiff
quantization_config: QuantizationConfig
init_from: InitFrom | str
timeout: Annotated[int, "Operation timeout in seconds"]
class QdrantCollectionCreateParams(
BaseCollectionParams, CommonCreateFields, total=False
):
"""High-level parameters for creating a Qdrant collection."""
pass
class CreateCollectionParams(CommonCreateFields, total=False):
"""Parameters for qdrant_client.create_collection."""
collection_name: str
class PreparedSearchParams(TypedDict):
"""Type definition for prepared Qdrant search parameters."""
collection_name: str
query: list[float]
limit: Annotated[int, "Max results to return"]
with_payload: Annotated[bool, "Include payload in results"]
with_vectors: Annotated[bool, "Include vectors in results"]
score_threshold: NotRequired[Annotated[float, "Min similarity score (0-1)"]]
query_filter: NotRequired[Filter]

View File

@@ -0,0 +1,228 @@
"""Utility functions for Qdrant operations."""
import asyncio
from typing import TypeGuard
from uuid import uuid4
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client.models import (
FieldCondition,
Filter,
MatchValue,
PointStruct,
QueryResponse,
)
from crewai.rag.qdrant.constants import DEFAULT_VECTOR_PARAMS
from crewai.rag.qdrant.types import (
AsyncEmbeddingFunction,
CreateCollectionParams,
EmbeddingFunction,
FilterCondition,
MetadataFilter,
PreparedSearchParams,
QdrantClientType,
QdrantCollectionCreateParams,
QueryEmbedding,
)
from crewai.rag.types import SearchResult, BaseRecord
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
"""Convert embedding to list[float] format if needed.
Args:
embedding: Embedding vector as list or numpy array.
Returns:
Embedding as list[float].
"""
if not isinstance(embedding, list):
return embedding.tolist()
return embedding
def _is_sync_client(client: QdrantClientType) -> TypeGuard[SyncQdrantClient]:
"""Type guard to check if the client is a synchronous QdrantClient.
Args:
client: The client to check.
Returns:
True if the client is a QdrantClient, False otherwise.
"""
return isinstance(client, SyncQdrantClient)
def _is_async_client(client: QdrantClientType) -> TypeGuard[AsyncQdrantClient]:
"""Type guard to check if the client is an asynchronous AsyncQdrantClient.
Args:
client: The client to check.
Returns:
True if the client is an AsyncQdrantClient, False otherwise.
"""
return isinstance(client, AsyncQdrantClient)
def _is_async_embedding_function(
func: EmbeddingFunction | AsyncEmbeddingFunction,
) -> TypeGuard[AsyncEmbeddingFunction]:
"""Type guard to check if the embedding function is async.
Args:
func: The embedding function to check.
Returns:
True if the function is async, False otherwise.
"""
return asyncio.iscoroutinefunction(func)
def _get_collection_params(
kwargs: QdrantCollectionCreateParams,
) -> CreateCollectionParams:
"""Extract collection creation parameters from kwargs."""
params: CreateCollectionParams = {
"collection_name": kwargs["collection_name"],
"vectors_config": kwargs.get("vectors_config", DEFAULT_VECTOR_PARAMS),
}
if "sparse_vectors_config" in kwargs:
params["sparse_vectors_config"] = kwargs["sparse_vectors_config"]
if "shard_number" in kwargs:
params["shard_number"] = kwargs["shard_number"]
if "sharding_method" in kwargs:
params["sharding_method"] = kwargs["sharding_method"]
if "replication_factor" in kwargs:
params["replication_factor"] = kwargs["replication_factor"]
if "write_consistency_factor" in kwargs:
params["write_consistency_factor"] = kwargs["write_consistency_factor"]
if "on_disk_payload" in kwargs:
params["on_disk_payload"] = kwargs["on_disk_payload"]
if "hnsw_config" in kwargs:
params["hnsw_config"] = kwargs["hnsw_config"]
if "optimizers_config" in kwargs:
params["optimizers_config"] = kwargs["optimizers_config"]
if "wal_config" in kwargs:
params["wal_config"] = kwargs["wal_config"]
if "quantization_config" in kwargs:
params["quantization_config"] = kwargs["quantization_config"]
if "init_from" in kwargs:
params["init_from"] = kwargs["init_from"]
if "timeout" in kwargs:
params["timeout"] = kwargs["timeout"]
return params
def _prepare_search_params(
collection_name: str,
query_embedding: QueryEmbedding,
limit: int,
score_threshold: float | None,
metadata_filter: MetadataFilter | None,
) -> PreparedSearchParams:
"""Prepare search parameters for Qdrant query_points.
Args:
collection_name: Name of the collection to search.
query_embedding: Embedding vector for the query.
limit: Maximum number of results.
score_threshold: Optional minimum similarity score.
metadata_filter: Optional metadata filters.
Returns:
Dictionary of parameters for query_points method.
"""
query_vector = _ensure_list_embedding(query_embedding)
search_kwargs: PreparedSearchParams = {
"collection_name": collection_name,
"query": query_vector,
"limit": limit,
"with_payload": True,
"with_vectors": False,
}
if score_threshold is not None:
search_kwargs["score_threshold"] = score_threshold
if metadata_filter:
filter_conditions: list[FilterCondition] = []
for key, value in metadata_filter.items():
filter_conditions.append(
FieldCondition(key=key, match=MatchValue(value=value))
)
search_kwargs["query_filter"] = Filter(must=filter_conditions)
return search_kwargs
def _normalize_qdrant_score(score: float) -> float:
"""Normalize Qdrant cosine similarity score to [0, 1] range.
Converts from Qdrant's [-1, 1] cosine similarity range to [0, 1] range for standardization across clients.
Args:
score: Raw cosine similarity score from Qdrant [-1, 1].
Returns:
Normalized score in [0, 1] range where 1 is most similar.
"""
normalized = (score + 1.0) / 2.0
return max(0.0, min(1.0, normalized))
def _process_search_results(response: QueryResponse) -> list[SearchResult]:
"""Process Qdrant search response into SearchResult format.
Args:
response: Response from Qdrant query_points method.
Returns:
List of SearchResult dictionaries.
"""
results: list[SearchResult] = []
for point in response.points:
payload = point.payload or {}
score = _normalize_qdrant_score(score=point.score)
result: SearchResult = {
"id": str(point.id),
"content": payload.get("content", ""),
"metadata": {k: v for k, v in payload.items() if k != "content"},
"score": score,
}
results.append(result)
return results
def _create_point_from_document(
doc: BaseRecord, embedding: QueryEmbedding
) -> PointStruct:
"""Create a PointStruct from a document and its embedding.
Args:
doc: Document dictionary containing content, metadata, and optional doc_id.
embedding: The embedding vector for the document content.
Returns:
PointStruct ready to be upserted to Qdrant.
"""
doc_id = doc.get("doc_id", str(uuid4()))
vector = _ensure_list_embedding(embedding)
metadata = doc.get("metadata", {})
if isinstance(metadata, list):
metadata = metadata[0] if metadata else {}
elif not isinstance(metadata, dict):
metadata = dict(metadata) if metadata else {}
return PointStruct(
id=doc_id,
vector=vector,
payload={"content": doc["content"], **metadata},
)

50
src/crewai/rag/types.py Normal file
View File

@@ -0,0 +1,50 @@
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
from collections.abc import Callable, Mapping
from typing import TypeAlias, Any
from typing_extensions import Required, TypedDict
class BaseRecord(TypedDict, total=False):
"""A typed dictionary representing a document record.
Attributes:
doc_id: Optional unique identifier for the document. If not provided,
a content-based ID will be generated using SHA256 hash.
content: The text content of the document (required)
metadata: Optional metadata associated with the document
"""
doc_id: str
content: Required[str]
metadata: (
Mapping[str, str | int | float | bool]
| list[Mapping[str, str | int | float | bool]]
)
DenseVector: TypeAlias = list[float]
IntVector: TypeAlias = list[int]
EmbeddingFunction: TypeAlias = Callable[..., Any]
class SearchResult(TypedDict):
"""Standard search result format for vector store queries.
This provides a consistent interface for search results across different
vector store implementations. Each implementation should convert their
native result format to this standard format.
Attributes:
id: Unique identifier of the document
content: The text content of the document
metadata: Optional metadata associated with the document
score: Similarity score (higher is better, typically between 0 and 1)
"""
id: str
content: str
metadata: dict[str, Any]
score: float

View File

@@ -4,6 +4,7 @@ import json
import logging
import threading
import uuid
import warnings
from concurrent.futures import Future
from copy import copy
from hashlib import md5
@@ -72,6 +73,10 @@ class Task(BaseModel):
output_pydantic: Pydantic model for task output.
security_config: Security configuration including fingerprinting.
tools: List of tools/resources limited for task execution.
allow_crewai_trigger_context: Optional flag to control crewai_trigger_payload injection.
None (default): Auto-inject for first task only.
True: Always inject trigger payload for this task.
False: Never inject trigger payload, even for first task.
"""
__hash__ = object.__hash__ # type: ignore
@@ -153,8 +158,13 @@ class Task(BaseModel):
default=None,
description="Function or string description of a guardrail to validate task output before proceeding to next task",
)
max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
max_retries: Optional[int] = Field(
default=None,
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0"
)
guardrail_max_retries: int = Field(
default=3,
description="Maximum number of retries when guardrail fails"
)
retry_count: int = Field(default=0, description="Current number of retries")
start_time: Optional[datetime.datetime] = Field(
@@ -163,6 +173,10 @@ class Task(BaseModel):
end_time: Optional[datetime.datetime] = Field(
default=None, description="End time of the task execution"
)
allow_crewai_trigger_context: Optional[bool] = Field(
default=None,
description="Whether this task should append 'Trigger Payload: {crewai_trigger_payload}' to the task description when crewai_trigger_payload exists in crew inputs.",
)
model_config = {"arbitrary_types_allowed": True}
@field_validator("guardrail")
@@ -346,6 +360,18 @@ class Task(BaseModel):
)
return self
@model_validator(mode="after")
def handle_max_retries_deprecation(self):
if self.max_retries is not None:
warnings.warn(
"The 'max_retries' parameter is deprecated and will be removed in CrewAI v1.0.0. "
"Please use 'guardrail_max_retries' instead.",
DeprecationWarning,
stacklevel=2
)
self.guardrail_max_retries = self.max_retries
return self
def execute_sync(
self,
agent: Optional[BaseAgent] = None,
@@ -425,7 +451,7 @@ class Task(BaseModel):
pydantic_output, json_output = self._export_output(result)
task_output = TaskOutput(
name=self.name,
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
@@ -442,9 +468,9 @@ class Task(BaseModel):
retry_count=self.retry_count,
)
if not guardrail_result.success:
if self.retry_count >= self.max_retries:
if self.retry_count >= self.guardrail_max_retries:
raise Exception(
f"Task failed guardrail validation after {self.max_retries} retries. "
f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
@@ -548,12 +574,23 @@ class Task(BaseModel):
str: The formatted prompt string containing the task description,
expected output, and optional markdown formatting instructions.
"""
tasks_slices = [self.description]
description = self.description
should_inject = self.allow_crewai_trigger_context
if should_inject and self.agent:
crew = getattr(self.agent, "crew", None)
if crew and hasattr(crew, "_inputs") and crew._inputs:
trigger_payload = crew._inputs.get("crewai_trigger_payload")
if trigger_payload is not None:
description += f"\n\nTrigger Payload: {trigger_payload}"
tasks_slices = [description]
output = self.i18n.slice("expected_output").format(
expected_output=self.expected_output
)
tasks_slices = [self.description, output]
tasks_slices = [description, output]
if self.markdown:
markdown_instruction = """Your final answer MUST be formatted in Markdown syntax.
@@ -761,7 +798,9 @@ Follow these guidelines:
if self.create_directory and not directory.exists():
directory.mkdir(parents=True, exist_ok=True)
elif not self.create_directory and not directory.exists():
raise RuntimeError(f"Directory {directory} does not exist and create_directory is False")
raise RuntimeError(
f"Directory {directory} does not exist and create_directory is False"
)
with resolved_path.open("w", encoding="utf-8") as file:
if isinstance(result, dict):

View File

@@ -14,12 +14,14 @@ from pydantic import BaseModel as PydanticBaseModel
from crewai.tools.structured_tool import CrewStructuredTool
class EnvVar(BaseModel):
name: str
description: str
required: bool = True
default: Optional[str] = None
class BaseTool(BaseModel, ABC):
class _ArgsSchemaPlaceholder(PydanticBaseModel):
pass
@@ -108,7 +110,7 @@ class BaseTool(BaseModel, ABC):
def to_structured_tool(self) -> CrewStructuredTool:
"""Convert this tool to a CrewStructuredTool instance."""
self._set_args_schema()
return CrewStructuredTool(
structured_tool = CrewStructuredTool(
name=self.name,
description=self.description,
args_schema=self.args_schema,
@@ -117,6 +119,8 @@ class BaseTool(BaseModel, ABC):
max_usage_count=self.max_usage_count,
current_usage_count=self.current_usage_count,
)
structured_tool._original_tool = self
return structured_tool
@classmethod
def from_langchain(cls, tool: Any) -> "BaseTool":
@@ -276,7 +280,9 @@ def to_langchain(
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
def tool(*args, result_as_answer: bool = False, max_usage_count: int | None = None) -> Callable:
def tool(
*args, result_as_answer: bool = False, max_usage_count: int | None = None
) -> Callable:
"""
Decorator to create a tool from a function.

View File

@@ -10,6 +10,17 @@ from pydantic import BaseModel, Field, create_model
from crewai.utilities.logger import Logger
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool
class ToolUsageLimitExceeded(Exception):
"""Exception raised when a tool has reached its maximum usage limit."""
pass
class CrewStructuredTool:
"""A structured tool that can operate on any number of inputs.
@@ -18,6 +29,8 @@ class CrewStructuredTool:
that integrates better with CrewAI's ecosystem.
"""
_original_tool: BaseTool | None = None
def __init__(
self,
name: str,
@@ -47,6 +60,7 @@ class CrewStructuredTool:
self.result_as_answer = result_as_answer
self.max_usage_count = max_usage_count
self.current_usage_count = current_usage_count
self._original_tool = None
# Validate the function signature matches the schema
self._validate_function_signature()
@@ -219,16 +233,26 @@ class CrewStructuredTool:
"""
parsed_args = self._parse_args(input)
if inspect.iscoroutinefunction(self.func):
return await self.func(**parsed_args, **kwargs)
else:
# Run sync functions in a thread pool
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, lambda: self.func(**parsed_args, **kwargs)
if self.has_reached_max_usage_count():
raise ToolUsageLimitExceeded(
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
)
self._increment_usage_count()
try:
if inspect.iscoroutinefunction(self.func):
return await self.func(**parsed_args, **kwargs)
else:
# Run sync functions in a thread pool
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, lambda: self.func(**parsed_args, **kwargs)
)
except Exception:
raise
def _run(self, *args, **kwargs) -> Any:
"""Legacy method for compatibility."""
# Convert args/kwargs to our expected format
@@ -242,10 +266,22 @@ class CrewStructuredTool:
"""Main method for tool execution."""
parsed_args = self._parse_args(input)
if self.has_reached_max_usage_count():
raise ToolUsageLimitExceeded(
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
)
self._increment_usage_count()
if inspect.iscoroutinefunction(self.func):
result = asyncio.run(self.func(**parsed_args, **kwargs))
return result
try:
result = self.func(**parsed_args, **kwargs)
except Exception:
raise
result = self.func(**parsed_args, **kwargs)
if asyncio.iscoroutine(result):
@@ -253,6 +289,19 @@ class CrewStructuredTool:
return result
def has_reached_max_usage_count(self) -> bool:
"""Check if the tool has reached its maximum usage count."""
return (
self.max_usage_count is not None
and self.current_usage_count >= self.max_usage_count
)
def _increment_usage_count(self) -> None:
"""Increment the usage count."""
self.current_usage_count += 1
if self._original_tool is not None:
self._original_tool.current_usage_count = self.current_usage_count
@property
def args(self) -> dict:
"""Get the tool's input arguments schema."""

View File

@@ -178,9 +178,11 @@ class ToolUsage:
if self.agent.fingerprint:
event_data.update(self.agent.fingerprint)
if self.task:
event_data["task_name"] = self.task.name or self.task.description
event_data["task_id"] = str(self.task.id)
crewai_event_bus.emit(self, ToolUsageStartedEvent(**event_data))
crewai_event_bus.emit(self,ToolUsageStartedEvent(**event_data))
started_at = time.time()
from_cache = False
result = None # type: ignore
@@ -311,12 +313,15 @@ class ToolUsage:
if self.agent and hasattr(self.agent, "tools_results"):
self.agent.tools_results.append(data)
if available_tool and hasattr(available_tool, 'current_usage_count'):
if available_tool and hasattr(available_tool, "current_usage_count"):
available_tool.current_usage_count += 1
if hasattr(available_tool, 'max_usage_count') and available_tool.max_usage_count is not None:
if (
hasattr(available_tool, "max_usage_count")
and available_tool.max_usage_count is not None
):
self._printer.print(
content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
color="blue"
color="blue",
)
return result
@@ -350,20 +355,20 @@ class ToolUsage:
calling.arguments == last_tool_usage.arguments
)
return False
def _check_usage_limit(self, tool: Any, tool_name: str) -> str | None:
"""Check if tool has reached its usage limit.
Args:
tool: The tool to check
tool_name: The name of the tool (used for error message)
Returns:
Error message if limit reached, None otherwise
"""
if (
hasattr(tool, 'max_usage_count')
and tool.max_usage_count is not None
hasattr(tool, "max_usage_count")
and tool.max_usage_count is not None
and tool.current_usage_count >= tool.max_usage_count
):
return f"Tool '{tool_name}' has reached its usage limit of {tool.max_usage_count} times and cannot be used anymore."
@@ -605,6 +610,9 @@ class ToolUsage:
"output": result,
}
)
if self.task:
event_data["task_id"] = str(self.task.id)
event_data["task_name"] = self.task.name or self.task.description
crewai_event_bus.emit(self, ToolUsageFinishedEvent(**event_data))
def _prepare_event_data(

View File

@@ -1,9 +1,10 @@
import os
import re
import portalocker
from chromadb import PersistentClient
from hashlib import md5
from typing import Optional
from crewai.utilities.paths import db_storage_path
MIN_COLLECTION_LENGTH = 3
MAX_COLLECTION_LENGTH = 63
@@ -27,7 +28,9 @@ def is_ipv4_pattern(name: str) -> bool:
return bool(IPV4_PATTERN.match(name))
def sanitize_collection_name(name: Optional[str], max_collection_length: int = MAX_COLLECTION_LENGTH) -> str:
def sanitize_collection_name(
name: Optional[str], max_collection_length: int = MAX_COLLECTION_LENGTH
) -> str:
"""
Sanitize a collection name to meet ChromaDB requirements:
1. 3-63 characters long
@@ -72,7 +75,8 @@ def create_persistent_client(path: str, **kwargs):
concurrent creations. Works for both multi-threads and multi-processes
environments.
"""
lockfile = f"chromadb-{md5(path.encode(), usedforsecurity=False).hexdigest()}.lock"
lock_id = md5(path.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(db_storage_path(), f"chromadb-{lock_id}.lock")
with portalocker.Lock(lockfile):
client = PersistentClient(path=path, **kwargs)

View File

@@ -11,7 +11,9 @@ class BaseEvent(BaseModel):
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
type: str
source_fingerprint: Optional[str] = None # UUID string of the source entity
source_type: Optional[str] = None # "agent", "task", "crew", "memory", "entity_memory", "short_term_memory", "long_term_memory", "external_memory"
source_type: Optional[str] = (
None # "agent", "task", "crew", "memory", "entity_memory", "short_term_memory", "long_term_memory", "external_memory"
)
fingerprint_metadata: Optional[Dict[str, Any]] = None # Any relevant metadata
def to_json(self, exclude: set[str] | None = None):
@@ -25,3 +27,20 @@ class BaseEvent(BaseModel):
dict: A JSON-serializable dictionary.
"""
return to_serializable(self, exclude=exclude)
def _set_task_params(self, data: Dict[str, Any]):
if "from_task" in data and (task := data["from_task"]):
self.task_id = task.id
self.task_name = task.name or task.description
self.from_task = None
def _set_agent_params(self, data: Dict[str, Any]):
task = data.get("from_task", None)
agent = task.agent if task else data.get("from_agent", None)
if not agent:
return
self.agent_id = agent.id
self.agent_role = agent.role
self.from_agent = None

View File

@@ -161,8 +161,10 @@ class EventListener(BaseEventListener):
def on_task_started(source, event: TaskStartedEvent):
span = self._telemetry.task_started(crew=source.agent.crew, task=source)
self.execution_spans[source] = span
# Pass both task ID and task name (if set)
task_name = source.name if hasattr(source, 'name') and source.name else None
self.formatter.create_task_branch(
self.formatter.current_crew_tree, source.id
self.formatter.current_crew_tree, source.id, task_name
)
@crewai_event_bus.on(TaskCompletedEvent)
@@ -173,11 +175,14 @@ class EventListener(BaseEventListener):
self._telemetry.task_ended(span, source, source.agent.crew)
self.execution_spans[source] = None
# Pass task name if it exists
task_name = source.name if hasattr(source, 'name') and source.name else None
self.formatter.update_task_status(
self.formatter.current_crew_tree,
source.id,
source.agent.role,
"completed",
task_name
)
@crewai_event_bus.on(TaskFailedEvent)
@@ -188,11 +193,14 @@ class EventListener(BaseEventListener):
self._telemetry.task_ended(span, source, source.agent.crew)
self.execution_spans[source] = None
# Pass task name if it exists
task_name = source.name if hasattr(source, 'name') and source.name else None
self.formatter.update_task_status(
self.formatter.current_crew_tree,
source.id,
source.agent.role,
"failed",
task_name
)
# ----------- AGENT EVENTS -----------

View File

@@ -40,17 +40,20 @@ class TraceBatch:
class TraceBatchManager:
"""Single responsibility: Manage batches and event buffering"""
is_current_batch_ephemeral: bool = False
trace_batch_id: Optional[str] = None
current_batch: Optional[TraceBatch] = None
event_buffer: List[TraceEvent] = []
execution_start_times: Dict[str, datetime] = {}
batch_owner_type: Optional[str] = None
batch_owner_id: Optional[str] = None
def __init__(self):
try:
self.plus_api = PlusAPI(api_key=get_auth_token())
except AuthError:
self.plus_api = PlusAPI(api_key="")
self.trace_batch_id: Optional[str] = None # Backend ID
self.current_batch: Optional[TraceBatch] = None
self.event_buffer: List[TraceEvent] = []
self.execution_start_times: Dict[str, datetime] = {}
def initialize_batch(
self,
user_context: Dict[str, str],
@@ -62,6 +65,7 @@ class TraceBatchManager:
user_context=user_context, execution_metadata=execution_metadata
)
self.event_buffer.clear()
self.is_current_batch_ephemeral = use_ephemeral
self.record_start_time("execution")
self._initialize_backend_batch(user_context, execution_metadata, use_ephemeral)
@@ -136,7 +140,7 @@ class TraceBatchManager:
"""Add event to buffer"""
self.event_buffer.append(trace_event)
def _send_events_to_backend(self, ephemeral: bool = True):
def _send_events_to_backend(self):
"""Send buffered events to backend"""
if not self.plus_api or not self.trace_batch_id or not self.event_buffer:
return
@@ -156,7 +160,7 @@ class TraceBatchManager:
response = (
self.plus_api.send_ephemeral_trace_events(self.trace_batch_id, payload)
if ephemeral
if self.is_current_batch_ephemeral
else self.plus_api.send_trace_events(self.trace_batch_id, payload)
)
@@ -170,29 +174,31 @@ class TraceBatchManager:
except Exception as e:
logger.error(f"❌ Error sending events to backend: {str(e)}")
def finalize_batch(self, ephemeral: bool = True) -> Optional[TraceBatch]:
def finalize_batch(self) -> Optional[TraceBatch]:
"""Finalize batch and return it for sending"""
if not self.current_batch:
return None
if self.event_buffer:
self._send_events_to_backend(ephemeral)
self._finalize_backend_batch(ephemeral)
self.current_batch.events = self.event_buffer.copy()
if self.event_buffer:
self._send_events_to_backend()
self._finalize_backend_batch()
finalized_batch = self.current_batch
self.batch_owner_type = None
self.batch_owner_id = None
self.current_batch = None
self.event_buffer.clear()
self.trace_batch_id = None
self.is_current_batch_ephemeral = False
self._cleanup_batch_data()
return finalized_batch
def _finalize_backend_batch(self, ephemeral: bool = True):
def _finalize_backend_batch(self):
"""Send batch finalization to backend"""
if not self.plus_api or not self.trace_batch_id:
return
@@ -210,7 +216,7 @@ class TraceBatchManager:
self.plus_api.finalize_ephemeral_trace_batch(
self.trace_batch_id, payload
)
if ephemeral
if self.is_current_batch_ephemeral
else self.plus_api.finalize_trace_batch(self.trace_batch_id, payload)
)
@@ -219,7 +225,7 @@ class TraceBatchManager:
console = Console()
return_link = (
f"{CREWAI_BASE_URL}/crewai_plus/trace_batches/{self.trace_batch_id}"
if not ephemeral and access_code
if not self.is_current_batch_ephemeral and access_code is None
else f"{CREWAI_BASE_URL}/crewai_plus/ephemeral_trace_batches/{self.trace_batch_id}?access_code={access_code}"
)
panel = Panel(

View File

@@ -75,10 +75,18 @@ class TraceCollectionListener(BaseEventListener):
Trace collection listener that orchestrates trace collection
"""
complex_events = ["task_started", "llm_call_started", "llm_call_completed"]
complex_events = [
"task_started",
"task_completed",
"llm_call_started",
"llm_call_completed",
"agent_execution_started",
"agent_execution_completed",
]
_instance = None
_initialized = False
_listeners_setup = False
def __new__(cls, batch_manager=None):
if cls._instance is None:
@@ -116,10 +124,15 @@ class TraceCollectionListener(BaseEventListener):
def setup_listeners(self, crewai_event_bus):
"""Setup event listeners - delegates to specific handlers"""
if self._listeners_setup:
return
self._register_flow_event_handlers(crewai_event_bus)
self._register_context_event_handlers(crewai_event_bus)
self._register_action_event_handlers(crewai_event_bus)
self._listeners_setup = True
def _register_flow_event_handlers(self, event_bus):
"""Register handlers for flow events"""
@@ -148,7 +161,8 @@ class TraceCollectionListener(BaseEventListener):
@event_bus.on(FlowFinishedEvent)
def on_flow_finished(source, event):
self._handle_trace_event("flow_finished", source, event)
self.batch_manager.finalize_batch()
if self.batch_manager.batch_owner_type == "flow":
self.batch_manager.finalize_batch()
@event_bus.on(FlowPlotEvent)
def on_flow_plot(source, event):
@@ -166,7 +180,8 @@ class TraceCollectionListener(BaseEventListener):
@event_bus.on(CrewKickoffCompletedEvent)
def on_crew_completed(source, event):
self._handle_trace_event("crew_kickoff_completed", source, event)
self.batch_manager.finalize_batch(ephemeral=True)
if self.batch_manager.batch_owner_type == "crew":
self.batch_manager.finalize_batch()
@event_bus.on(CrewKickoffFailedEvent)
def on_crew_failed(source, event):
@@ -218,7 +233,7 @@ class TraceCollectionListener(BaseEventListener):
self._handle_trace_event("llm_guardrail_completed", source, event)
def _register_action_event_handlers(self, event_bus):
"""Register handlers for action events (LLM calls, tool usage, memory)"""
"""Register handlers for action events (LLM calls, tool usage)"""
@event_bus.on(LLMCallStartedEvent)
def on_llm_call_started(source, event):
@@ -289,6 +304,9 @@ class TraceCollectionListener(BaseEventListener):
"crewai_version": get_crewai_version(),
}
self.batch_manager.batch_owner_type = "crew"
self.batch_manager.batch_owner_id = getattr(source, "id", str(uuid.uuid4()))
self._initialize_batch(user_context, execution_metadata)
def _initialize_flow_batch(self, source: Any, event: Any):
@@ -301,6 +319,9 @@ class TraceCollectionListener(BaseEventListener):
"execution_type": "flow",
}
self.batch_manager.batch_owner_type = "flow"
self.batch_manager.batch_owner_id = getattr(source, "id", str(uuid.uuid4()))
self._initialize_batch(user_context, execution_metadata)
def _initialize_batch(
@@ -358,12 +379,44 @@ class TraceCollectionListener(BaseEventListener):
return {
"task_description": event.task.description,
"expected_output": event.task.expected_output,
"task_name": event.task.name,
"task_name": event.task.name or event.task.description,
"context": event.context,
"agent": source.agent.role,
"agent_role": source.agent.role,
"task_id": str(event.task.id),
}
elif event_type == "task_completed":
return {
"task_description": event.task.description if event.task else None,
"task_name": event.task.name or event.task.description
if event.task
else None,
"task_id": str(event.task.id) if event.task else None,
"output_raw": event.output.raw if event.output else None,
"output_format": str(event.output.output_format)
if event.output
else None,
"agent_role": event.output.agent if event.output else None,
}
elif event_type == "agent_execution_started":
return {
"agent_role": event.agent.role,
"agent_goal": event.agent.goal,
"agent_backstory": event.agent.backstory,
}
elif event_type == "agent_execution_completed":
return {
"agent_role": event.agent.role,
"agent_goal": event.agent.goal,
"agent_backstory": event.agent.backstory,
}
elif event_type == "llm_call_started":
return self._safe_serialize_to_dict(event)
event_data = self._safe_serialize_to_dict(event)
event_data["task_name"] = (
event.task_name or event.task_description
if hasattr(event, "task_name") and event.task_name
else None
)
return event_data
elif event_type == "llm_call_completed":
return self._safe_serialize_to_dict(event)
else:

Some files were not shown because too many files have changed in this diff Show More