Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-05-03 16:22:49 +00:00)

Commit: Merge branch 'main' into devin/1763567753-fix-task-planner-ordering

@@ -62,9 +62,9 @@
 With over 100,000 developers certified through our community courses at [learn.crewai.com](https://learn.crewai.com), CrewAI is rapidly becoming the
 standard for enterprise-ready AI automation.

-# CrewAI AMP Suite
+# CrewAI AOP Suite

-CrewAI AMP Suite is a comprehensive bundle tailored for organizations that require secure, scalable, and easy-to-manage agent-driven automation.
+CrewAI AOP Suite is a comprehensive bundle tailored for organizations that require secure, scalable, and easy-to-manage agent-driven automation.

 You can try one part of the suite the [Crew Control Plane for free](https://app.crewai.com)

@@ -76,9 +76,9 @@ You can try one part of the suite the [Crew Control Plane for free](https://app.
 - **Advanced Security**: Built-in robust security and compliance measures ensuring safe deployment and management.
 - **Actionable Insights**: Real-time analytics and reporting to optimize performance and decision-making.
 - **24/7 Support**: Dedicated enterprise support to ensure uninterrupted operation and quick resolution of issues.
-- **On-premise and Cloud Deployment Options**: Deploy CrewAI AMP on-premise or in the cloud, depending on your security and compliance requirements.
+- **On-premise and Cloud Deployment Options**: Deploy CrewAI AOP on-premise or in the cloud, depending on your security and compliance requirements.

-CrewAI AMP is designed for enterprises seeking a powerful, reliable solution to transform complex business processes into efficient,
+CrewAI AOP is designed for enterprises seeking a powerful, reliable solution to transform complex business processes into efficient,
 intelligent automations.

 ## Table of contents

@@ -674,9 +674,9 @@ CrewAI is released under the [MIT License](https://github.com/crewAIInc/crewAI/b

 ### Enterprise Features

-- [What additional features does CrewAI AMP offer?](#q-what-additional-features-does-crewai-enterprise-offer)
-- [Is CrewAI AMP available for cloud and on-premise deployments?](#q-is-crewai-enterprise-available-for-cloud-and-on-premise-deployments)
-- [Can I try CrewAI AMP for free?](#q-can-i-try-crewai-enterprise-for-free)
+- [What additional features does CrewAI AOP offer?](#q-what-additional-features-does-crewai-enterprise-offer)
+- [Is CrewAI AOP available for cloud and on-premise deployments?](#q-is-crewai-enterprise-available-for-cloud-and-on-premise-deployments)
+- [Can I try CrewAI AOP for free?](#q-can-i-try-crewai-enterprise-for-free)

 ### Q: What exactly is CrewAI?

@@ -732,17 +732,17 @@ A: Check out practical examples in the [CrewAI-examples repository](https://gith

 A: Contributions are warmly welcomed! Fork the repository, create your branch, implement your changes, and submit a pull request. See the Contribution section of the README for detailed guidelines.

-### Q: What additional features does CrewAI AMP offer?
+### Q: What additional features does CrewAI AOP offer?

-A: CrewAI AMP provides advanced features such as a unified control plane, real-time observability, secure integrations, advanced security, actionable insights, and dedicated 24/7 enterprise support.
+A: CrewAI AOP provides advanced features such as a unified control plane, real-time observability, secure integrations, advanced security, actionable insights, and dedicated 24/7 enterprise support.

-### Q: Is CrewAI AMP available for cloud and on-premise deployments?
+### Q: Is CrewAI AOP available for cloud and on-premise deployments?

-A: Yes, CrewAI AMP supports both cloud-based and on-premise deployment options, allowing enterprises to meet their specific security and compliance requirements.
+A: Yes, CrewAI AOP supports both cloud-based and on-premise deployment options, allowing enterprises to meet their specific security and compliance requirements.

-### Q: Can I try CrewAI AMP for free?
+### Q: Can I try CrewAI AOP for free?

-A: Yes, you can explore part of the CrewAI AMP Suite by accessing the [Crew Control Plane](https://app.crewai.com) for free.
+A: Yes, you can explore part of the CrewAI AOP Suite by accessing the [Crew Control Plane](https://app.crewai.com) for free.

 ### Q: Does CrewAI support fine-tuning or training custom models?

@@ -762,7 +762,7 @@ A: CrewAI is highly scalable, supporting simple automations and large-scale ente

 ### Q: Does CrewAI offer debugging and monitoring tools?

-A: Yes, CrewAI AMP includes advanced debugging, tracing, and real-time observability features, simplifying the management and troubleshooting of your automations.
+A: Yes, CrewAI AOP includes advanced debugging, tracing, and real-time observability features, simplifying the management and troubleshooting of your automations.

 ### Q: What programming languages does CrewAI support?

@@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"

 [project.optional-dependencies]
 tools = [
-    "crewai-tools==1.5.0",
+    "crewai-tools==1.6.1",
 ]
 embeddings = [
     "tiktoken~=0.8.0"

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:

 _suppress_pydantic_deprecation_warnings()

-__version__ = "1.5.0"
+__version__ = "1.6.1"
 _telemetry_submitted = False

@@ -951,7 +951,7 @@ class Agent(BaseAgent):
             raise RuntimeError(f"Failed to get native MCP tools: {e}") from e

     def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
-        """Get tools from CrewAI AMP MCP marketplace."""
+        """Get tools from CrewAI AOP MCP marketplace."""
         # Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name"
         amp_part = amp_ref.replace("crewai-amp:", "")
         if "#" in amp_part:

@@ -1204,7 +1204,7 @@ class Agent(BaseAgent):

     @staticmethod
     def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
-        """Fetch MCP server configurations from CrewAI AMP API."""
+        """Fetch MCP server configurations from CrewAI AOP API."""
         # TODO: Implement AMP API call to "integrations/mcps" endpoint
         # Should return list of server configs with URLs
         return []

@@ -83,7 +83,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
         knowledge_sources: Knowledge sources for the agent.
         knowledge_storage: Custom knowledge storage for the agent.
         security_config: Security configuration for the agent, including fingerprinting.
-        apps: List of enterprise applications that the agent can access through CrewAI AMP Tools.
+        apps: List of enterprise applications that the agent can access through CrewAI AOP Tools.

     Methods:
         execute_task(task: Any, context: str | None = None, tools: list[BaseTool] | None = None) -> str:

@@ -67,7 +67,11 @@ class ProviderFactory:
         module = importlib.import_module(
             f"crewai.cli.authentication.providers.{settings.provider.lower()}"
         )
-        provider = getattr(module, f"{settings.provider.capitalize()}Provider")
+        # Converts from snake_case to CamelCase to obtain the provider class name.
+        provider = getattr(
+            module,
+            f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider",
+        )

         return cast("BaseProvider", provider(settings))

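Why this matters: `str.capitalize()` lowercases everything after the first character, so multi-word provider names could never resolve to a class. A standalone sketch of both conversions (the `entra_id` value mirrors the provider added later in this diff):

    provider = "entra_id"

    # Old lookup name: str.capitalize() only uppercases the first character.
    old_name = f"{provider.capitalize()}Provider"
    # -> "Entra_idProvider" (no such class)

    # New lookup name: capitalize each snake_case word, then join.
    new_name = f"{''.join(word.capitalize() for word in provider.split('_'))}Provider"
    # -> "EntraIdProvider"
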
@@ -79,7 +83,7 @@ class AuthenticationCommand:

     def login(self) -> None:
         """Sign up to CrewAI+"""
-        console.print("Signing in to CrewAI AMP...\n", style="bold blue")
+        console.print("Signing in to CrewAI AOP...\n", style="bold blue")

         device_code_data = self._get_device_code()
         self._display_auth_instructions(device_code_data)

@@ -91,7 +95,7 @@ class AuthenticationCommand:

         device_code_payload = {
            "client_id": self.oauth2_provider.get_client_id(),
-           "scope": "openid",
+           "scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
            "audience": self.oauth2_provider.get_audience(),
         }
         response = requests.post(

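OAuth2 device-flow requests carry multiple scopes as a single space-delimited string; a minimal sketch using the BaseProvider defaults introduced later in this diff:

    scopes = ["openid", "profile", "email"]  # BaseProvider.get_oauth_scopes() default
    payload = {"scope": " ".join(scopes)}
    print(payload["scope"])  # openid profile email
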
@@ -104,9 +108,14 @@ class AuthenticationCommand:

     def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
         """Display the authentication instructions to the user."""
-        console.print("1. Navigate to: ", device_code_data["verification_uri_complete"])
+        verification_uri = device_code_data.get(
+            "verification_uri_complete", device_code_data.get("verification_uri", "")
+        )
+
+        console.print("1. Navigate to: ", verification_uri)
         console.print("2. Enter the following code: ", device_code_data["user_code"])
-        webbrowser.open(device_code_data["verification_uri_complete"])
+        webbrowser.open(verification_uri)

     def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
         """Polls the server for the token until it is received, or max attempts are reached."""

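Some identity providers (Entra ID among them) return only `verification_uri` in the device-code response, so indexing `verification_uri_complete` directly raised a KeyError. A sketch with a hypothetical response:

    # Hypothetical device-code response lacking verification_uri_complete.
    device_code_data = {
        "verification_uri": "https://microsoft.com/devicelogin",
        "user_code": "ABCD-1234",
    }
    verification_uri = device_code_data.get(
        "verification_uri_complete", device_code_data.get("verification_uri", "")
    )
    print(verification_uri)  # https://microsoft.com/devicelogin
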
@@ -136,7 +145,7 @@ class AuthenticationCommand:

             self._login_to_tool_repository()

-            console.print("\n[bold green]Welcome to CrewAI AMP![/bold green]\n")
+            console.print("\n[bold green]Welcome to CrewAI AOP![/bold green]\n")
             return

         if token_data["error"] not in ("authorization_pending", "slow_down"):

@@ -186,8 +195,9 @@ class AuthenticationCommand:
             )

             settings = Settings()
+
             console.print(
-                f"You are authenticated to the tool repository as [bold cyan]'{settings.org_name}'[/bold cyan] ({settings.org_uuid})",
+                f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]",
                 style="green",
             )
         except Exception:

@@ -28,3 +28,6 @@ class BaseProvider(ABC):
     def get_required_fields(self) -> list[str]:
         """Returns which provider-specific fields inside the "extra" dict will be required"""
         return []
+
+    def get_oauth_scopes(self) -> list[str]:
+        return ["openid", "profile", "email"]

@@ -0,0 +1,43 @@
+from typing import cast
+
+from crewai.cli.authentication.providers.base_provider import BaseProvider
+
+
+class EntraIdProvider(BaseProvider):
+    def get_authorize_url(self) -> str:
+        return f"{self._base_url()}/oauth2/v2.0/devicecode"
+
+    def get_token_url(self) -> str:
+        return f"{self._base_url()}/oauth2/v2.0/token"
+
+    def get_jwks_url(self) -> str:
+        return f"{self._base_url()}/discovery/v2.0/keys"
+
+    def get_issuer(self) -> str:
+        return f"{self._base_url()}/v2.0"
+
+    def get_audience(self) -> str:
+        if self.settings.audience is None:
+            raise ValueError(
+                "Audience is required. Please set it in the configuration."
+            )
+        return self.settings.audience
+
+    def get_client_id(self) -> str:
+        if self.settings.client_id is None:
+            raise ValueError(
+                "Client ID is required. Please set it in the configuration."
+            )
+        return self.settings.client_id
+
+    def get_oauth_scopes(self) -> list[str]:
+        return [
+            *super().get_oauth_scopes(),
+            *cast(str, self.settings.extra.get("scope", "")).split(),
+        ]
+
+    def get_required_fields(self) -> list[str]:
+        return ["scope"]
+
+    def _base_url(self) -> str:
+        return f"https://login.microsoftonline.com/{self.settings.domain}"

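How the scope merge behaves, sketched with a stand-in settings object (field names follow the diff; the tenant, audience, and scope values are invented):

    from types import SimpleNamespace

    settings = SimpleNamespace(
        domain="contoso.onmicrosoft.com",
        audience="api://my-app",
        client_id="00000000-0000-0000-0000-000000000000",
        extra={"scope": "api://my-app/.default offline_access"},
    )

    base = ["openid", "profile", "email"]                       # BaseProvider default
    merged = [*base, *settings.extra.get("scope", "").split()]  # EntraIdProvider override
    print(merged)
    # ['openid', 'profile', 'email', 'api://my-app/.default', 'offline_access']
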
@@ -1,10 +1,12 @@
+from typing import Any
+
 import jwt
 from jwt import PyJWKClient


 def validate_jwt_token(
     jwt_token: str, jwks_url: str, issuer: str, audience: str
-) -> dict:
+) -> Any:
     """
     Verify the token's signature and claims using PyJWT.
     :param jwt_token: The JWT (JWS) string to validate.

@@ -24,6 +26,7 @@ def validate_jwt_token(
     _unverified_decoded_token = jwt.decode(
         jwt_token, options={"verify_signature": False}
     )
+
     return jwt.decode(
         jwt_token,
         signing_key.key,

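Putting the provider and validator together — a hypothetical call; `<tenant>` is a placeholder, the audience is invented, and `access_token` would come from the device-code flow above:

    claims = validate_jwt_token(
        jwt_token=access_token,
        jwks_url="https://login.microsoftonline.com/<tenant>/discovery/v2.0/keys",
        issuer="https://login.microsoftonline.com/<tenant>/v2.0",
        audience="api://my-app",
    )
    print(claims["sub"])  # subject claim of the verified token
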
@@ -271,7 +271,7 @@ def update():

 @crewai.command()
 def login():
-    """Sign Up/Login to CrewAI AMP."""
+    """Sign Up/Login to CrewAI AOP."""
     Settings().clear_user_settings()
     AuthenticationCommand().login()

@@ -460,7 +460,7 @@ def enterprise():
 @enterprise.command("configure")
 @click.argument("enterprise_url")
 def enterprise_configure(enterprise_url: str):
-    """Configure CrewAI AMP OAuth2 settings from the provided Enterprise URL."""
+    """Configure CrewAI AOP OAuth2 settings from the provided Enterprise URL."""
     enterprise_command = EnterpriseConfigureCommand()
     enterprise_command.configure(enterprise_url)

@@ -73,6 +73,7 @@ CLI_SETTINGS_KEYS = [
     "oauth2_audience",
     "oauth2_client_id",
     "oauth2_domain",
+    "oauth2_extra",
 ]

 # Default values for CLI settings

@@ -82,6 +83,7 @@ DEFAULT_CLI_SETTINGS = {
     "oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
     "oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
     "oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
+    "oauth2_extra": {},
 }

 # Readonly settings - cannot be set by the user

@@ -101,7 +103,7 @@ HIDDEN_SETTINGS_KEYS = [
 class Settings(BaseModel):
     enterprise_base_url: str | None = Field(
         default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
-        description="Base URL of the CrewAI AMP instance",
+        description="Base URL of the CrewAI AOP instance",
     )
     tool_repository_username: str | None = Field(
         None, description="Username for interacting with the Tool Repository"

@@ -145,6 +145,7 @@ MODELS = {
         "claude-3-haiku-20240307",
     ],
     "gemini": [
+        "gemini/gemini-3-pro-preview",
         "gemini/gemini-1.5-flash",
         "gemini/gemini-1.5-pro",
         "gemini/gemini-2.0-flash-lite-001",

@@ -27,7 +27,7 @@ class EnterpriseConfigureCommand(BaseCommand):
         self._update_oauth_settings(enterprise_url, oauth_config)

         console.print(
-            f"✅ Successfully configured CrewAI AMP with OAuth2 settings from {enterprise_url}",
+            f"✅ Successfully configured CrewAI AOP with OAuth2 settings from {enterprise_url}",
             style="bold green",
         )

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
 authors = [{ name = "Your Name", email = "you@example.com" }]
 requires-python = ">=3.10,<3.14"
 dependencies = [
-    "crewai[tools]==1.5.0"
+    "crewai[tools]==1.6.1"
 ]

 [project.scripts]

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
 authors = [{ name = "Your Name", email = "you@example.com" }]
 requires-python = ">=3.10,<3.14"
 dependencies = [
-    "crewai[tools]==1.5.0"
+    "crewai[tools]==1.6.1"
 ]

 [project.scripts]

@@ -162,7 +162,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):

         if login_response.status_code != 200:
             console.print(
-                "Authentication failed. Verify access to the tool repository, or try `crewai login`. ",
+                "Authentication failed. Verify that the currently active organization has access to the tool repository, and run 'crewai login' again. ",
                 style="bold red",
             )
             raise SystemExit

@@ -74,6 +74,7 @@ from crewai.tasks.conditional_task import ConditionalTask
 from crewai.tasks.task_output import TaskOutput
 from crewai.tools.agent_tools.agent_tools import AgentTools
 from crewai.tools.base_tool import BaseTool
+from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
 from crewai.types.usage_metrics import UsageMetrics
 from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
 from crewai.utilities.crew.models import CrewContext

@@ -90,6 +91,14 @@ from crewai.utilities.logger import Logger
 from crewai.utilities.planning_handler import CrewPlanner
 from crewai.utilities.printer import PrinterColor
 from crewai.utilities.rpm_controller import RPMController
+from crewai.utilities.streaming import (
+    TaskInfo,
+    create_async_chunk_generator,
+    create_chunk_generator,
+    create_streaming_state,
+    signal_end,
+    signal_error,
+)
 from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
 from crewai.utilities.training_handler import CrewTrainingHandler

@@ -225,6 +234,10 @@ class Crew(FlowTrackable, BaseModel):
             "It may be used to adjust the output of the crew."
         ),
     )
+    stream: bool = Field(
+        default=False,
+        description="Whether to stream output from the crew execution.",
+    )
     max_rpm: int | None = Field(
         default=None,
         description=(

@@ -660,7 +673,43 @@ class Crew(FlowTrackable, BaseModel):
     def kickoff(
         self,
         inputs: dict[str, Any] | None = None,
-    ) -> CrewOutput:
+    ) -> CrewOutput | CrewStreamingOutput:
+        if self.stream:
+            for agent in self.agents:
+                if agent.llm is not None:
+                    agent.llm.stream = True
+
+            result_holder: list[CrewOutput] = []
+            current_task_info: TaskInfo = {
+                "index": 0,
+                "name": "",
+                "id": "",
+                "agent_role": "",
+                "agent_id": "",
+            }
+
+            state = create_streaming_state(current_task_info, result_holder)
+            output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
+
+            def run_crew() -> None:
+                """Execute the crew and capture the result."""
+                try:
+                    self.stream = False
+                    crew_result = self.kickoff(inputs=inputs)
+                    if isinstance(crew_result, CrewOutput):
+                        result_holder.append(crew_result)
+                except Exception as exc:
+                    signal_error(state, exc)
+                finally:
+                    self.stream = True
+                    signal_end(state)
+
+            streaming_output = CrewStreamingOutput(
+                sync_iterator=create_chunk_generator(state, run_crew, output_holder)
+            )
+            output_holder.append(streaming_output)
+            return streaming_output
+
         ctx = baggage.set_baggage(
             "crew_context", CrewContext(id=str(self.id), key=self.key)
         )

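A minimal usage sketch of the new synchronous streaming path, assuming `researcher` and `research_task` are already defined; chunks are printed as received, and the final CrewOutput is read from `.result` after iteration, per the docstrings in this diff:

    crew = Crew(agents=[researcher], tasks=[research_task], stream=True)

    streaming = crew.kickoff(inputs={"topic": "AI agents"})
    for chunk in streaming:            # chunks arrive while the crew runs
        print(chunk, end="", flush=True)

    final_output = streaming.result   # CrewOutput, available once iteration ends
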
@@ -726,11 +775,16 @@ class Crew(FlowTrackable, BaseModel):
         finally:
             detach(token)

-    def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
-        """Executes the Crew's workflow for each input and aggregates results."""
-        results: list[CrewOutput] = []
+    def kickoff_for_each(
+        self, inputs: list[dict[str, Any]]
+    ) -> list[CrewOutput | CrewStreamingOutput]:
+        """Executes the Crew's workflow for each input and aggregates results.
+
+        If stream=True, returns a list of CrewStreamingOutput objects that must
+        each be iterated to get stream chunks and access results.
+        """
+        results: list[CrewOutput | CrewStreamingOutput] = []

         # Initialize the parent crew's usage metrics
         total_usage_metrics = UsageMetrics()

         for input_data in inputs:

@@ -738,43 +792,161 @@ class Crew(FlowTrackable, BaseModel):

             output = crew.kickoff(inputs=input_data)

-            if crew.usage_metrics:
+            if not self.stream and crew.usage_metrics:
                 total_usage_metrics.add_usage_metrics(crew.usage_metrics)

             results.append(output)

-        self.usage_metrics = total_usage_metrics
+        if not self.stream:
+            self.usage_metrics = total_usage_metrics
         self._task_output_handler.reset()
         return results

-    async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> CrewOutput:
-        """Asynchronous kickoff method to start the crew execution."""
+    async def kickoff_async(
+        self, inputs: dict[str, Any] | None = None
+    ) -> CrewOutput | CrewStreamingOutput:
+        """Asynchronous kickoff method to start the crew execution.
+
+        If stream=True, returns a CrewStreamingOutput that can be async-iterated
+        to get stream chunks. After iteration completes, access the final result
+        via .result.
+        """
+        inputs = inputs or {}
+
+        if self.stream:
+            for agent in self.agents:
+                if agent.llm is not None:
+                    agent.llm.stream = True
+
+            result_holder: list[CrewOutput] = []
+            current_task_info: TaskInfo = {
+                "index": 0,
+                "name": "",
+                "id": "",
+                "agent_role": "",
+                "agent_id": "",
+            }
+
+            state = create_streaming_state(
+                current_task_info, result_holder, use_async=True
+            )
+            output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
+
+            async def run_crew() -> None:
+                try:
+                    self.stream = False
+                    result = await asyncio.to_thread(self.kickoff, inputs)
+                    if isinstance(result, CrewOutput):
+                        result_holder.append(result)
+                except Exception as e:
+                    signal_error(state, e, is_async=True)
+                finally:
+                    self.stream = True
+                    signal_end(state, is_async=True)
+
+            streaming_output = CrewStreamingOutput(
+                async_iterator=create_async_chunk_generator(
+                    state, run_crew, output_holder
+                )
+            )
+            output_holder.append(streaming_output)
+
+            return streaming_output
+
         return await asyncio.to_thread(self.kickoff, inputs)

     async def kickoff_for_each_async(
         self, inputs: list[dict[str, Any]]
-    ) -> list[CrewOutput]:
-        """Executes the Crew's workflow for each input asynchronously."""
+    ) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
+        """Executes the Crew's workflow for each input asynchronously.
+
+        If stream=True, returns a single CrewStreamingOutput that yields chunks
+        from all crews as they arrive. After iteration, access results via .results
+        (list of CrewOutput).
+        """
         crew_copies = [self.copy() for _ in inputs]

-        async def run_crew(crew: Self, input_data: Any) -> CrewOutput:
-            return await crew.kickoff_async(inputs=input_data)
+        if self.stream:
+            result_holder: list[list[CrewOutput]] = [[]]
+            current_task_info: TaskInfo = {
+                "index": 0,
+                "name": "",
+                "id": "",
+                "agent_role": "",
+                "agent_id": "",
+            }
+
+            state = create_streaming_state(
+                current_task_info, result_holder, use_async=True
+            )
+            output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
+
+            async def run_all_crews() -> None:
+                """Run all crew copies and aggregate their streaming outputs."""
+                try:
+                    streaming_outputs: list[CrewStreamingOutput] = []
+                    for i, crew in enumerate(crew_copies):
+                        streaming = await crew.kickoff_async(inputs=inputs[i])
+                        if isinstance(streaming, CrewStreamingOutput):
+                            streaming_outputs.append(streaming)
+
+                    async def consume_stream(
+                        stream_output: CrewStreamingOutput,
+                    ) -> CrewOutput:
+                        """Consume stream chunks and forward to parent queue.
+
+                        Args:
+                            stream_output: The streaming output to consume.
+
+                        Returns:
+                            The final CrewOutput result.
+                        """
+                        async for chunk in stream_output:
+                            if state.async_queue is not None and state.loop is not None:
+                                state.loop.call_soon_threadsafe(
+                                    state.async_queue.put_nowait, chunk
+                                )
+                        return stream_output.result
+
+                    crew_results = await asyncio.gather(
+                        *[consume_stream(s) for s in streaming_outputs]
+                    )
+                    result_holder[0] = list(crew_results)
+                except Exception as e:
+                    signal_error(state, e, is_async=True)
+                finally:
+                    signal_end(state, is_async=True)
+
+            streaming_output = CrewStreamingOutput(
+                async_iterator=create_async_chunk_generator(
+                    state, run_all_crews, output_holder
+                )
+            )
+
+            def set_results_wrapper(result: Any) -> None:
+                """Wrap _set_results to match _set_result signature."""
+                streaming_output._set_results(result)
+
+            streaming_output._set_result = set_results_wrapper  # type: ignore[method-assign]
+            output_holder.append(streaming_output)
+
+            return streaming_output
+
         tasks = [
-            asyncio.create_task(run_crew(crew_copies[i], inputs[i]))
-            for i in range(len(inputs))
+            asyncio.create_task(crew_copy.kickoff_async(inputs=input_data))
+            for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
         ]

         results = await asyncio.gather(*tasks)

         total_usage_metrics = UsageMetrics()
-        for crew in crew_copies:
-            if crew.usage_metrics:
-                total_usage_metrics.add_usage_metrics(crew.usage_metrics)
+        for crew_copy in crew_copies:
+            if crew_copy.usage_metrics:
+                total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
         self.usage_metrics = total_usage_metrics

         self._task_output_handler.reset()
-        return results
+        return list(results)

     def _handle_crew_planning(self) -> None:
         """Handles the Crew planning."""

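And the async counterpart, again a sketch based on the docstrings above (`researcher` and `research_task` are assumed to exist):

    import asyncio

    async def main() -> None:
        crew = Crew(agents=[researcher], tasks=[research_task], stream=True)
        streaming = await crew.kickoff_async(inputs={"topic": "AI agents"})
        async for chunk in streaming:
            print(chunk, end="", flush=True)
        print(streaming.result)

    asyncio.run(main())
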
@@ -101,24 +101,25 @@ if TYPE_CHECKING:


 class EventListener(BaseEventListener):
-    _instance = None
+    _instance: EventListener | None = None
+    _initialized: bool = False
     _telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
-    logger = Logger(verbose=True, default_color=EMITTER_COLOR)
+    logger: Logger = Logger(verbose=True, default_color=EMITTER_COLOR)
     execution_spans: dict[Task, Any] = Field(default_factory=dict)
-    next_chunk = 0
-    text_stream = StringIO()
-    knowledge_retrieval_in_progress = False
-    knowledge_query_in_progress = False
+    next_chunk: int = 0
+    text_stream: StringIO = StringIO()
+    knowledge_retrieval_in_progress: bool = False
+    knowledge_query_in_progress: bool = False
     method_branches: dict[str, Any] = Field(default_factory=dict)

-    def __new__(cls):
+    def __new__(cls) -> EventListener:
         if cls._instance is None:
             cls._instance = super().__new__(cls)
             cls._instance._initialized = False
         return cls._instance

-    def __init__(self):
-        if not hasattr(self, "_initialized") or not self._initialized:
+    def __init__(self) -> None:
+        if not self._initialized:
             super().__init__()
             self._telemetry = Telemetry()
             self._telemetry.set_tracer()

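The `__new__` override keeps EventListener a process-wide singleton, which this small sketch demonstrates:

    listener_a = EventListener()
    listener_b = EventListener()
    assert listener_a is listener_b   # both names point at the same instance
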
@@ -136,14 +137,14 @@ class EventListener(BaseEventListener):

     def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
         @crewai_event_bus.on(CrewKickoffStartedEvent)
-        def on_crew_started(source, event: CrewKickoffStartedEvent) -> None:
+        def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
             with self._crew_tree_lock:
                 self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
                 self._telemetry.crew_execution_span(source, event.inputs)
                 self._crew_tree_lock.notify_all()

         @crewai_event_bus.on(CrewKickoffCompletedEvent)
-        def on_crew_completed(source, event: CrewKickoffCompletedEvent) -> None:
+        def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None:
             # Handle telemetry
             final_string_output = event.output.raw
             self._telemetry.end_crew(source, final_string_output)

@@ -157,7 +158,7 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(CrewKickoffFailedEvent)
-        def on_crew_failed(source, event: CrewKickoffFailedEvent) -> None:
+        def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None:
             self.formatter.update_crew_tree(
                 self.formatter.current_crew_tree,
                 event.crew_name or "Crew",

@@ -166,23 +167,23 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(CrewTrainStartedEvent)
-        def on_crew_train_started(source, event: CrewTrainStartedEvent) -> None:
+        def on_crew_train_started(_: Any, event: CrewTrainStartedEvent) -> None:
             self.formatter.handle_crew_train_started(
                 event.crew_name or "Crew", str(event.timestamp)
             )

         @crewai_event_bus.on(CrewTrainCompletedEvent)
-        def on_crew_train_completed(source, event: CrewTrainCompletedEvent) -> None:
+        def on_crew_train_completed(_: Any, event: CrewTrainCompletedEvent) -> None:
             self.formatter.handle_crew_train_completed(
                 event.crew_name or "Crew", str(event.timestamp)
             )

         @crewai_event_bus.on(CrewTrainFailedEvent)
-        def on_crew_train_failed(source, event: CrewTrainFailedEvent) -> None:
+        def on_crew_train_failed(_: Any, event: CrewTrainFailedEvent) -> None:
             self.formatter.handle_crew_train_failed(event.crew_name or "Crew")

         @crewai_event_bus.on(CrewTestResultEvent)
-        def on_crew_test_result(source, event: CrewTestResultEvent) -> None:
+        def on_crew_test_result(source: Any, event: CrewTestResultEvent) -> None:
             self._telemetry.individual_test_result_span(
                 source.crew,
                 event.quality,

@@ -193,7 +194,7 @@ class EventListener(BaseEventListener):
         # ----------- TASK EVENTS -----------

         @crewai_event_bus.on(TaskStartedEvent)
-        def on_task_started(source, event: TaskStartedEvent) -> None:
+        def on_task_started(source: Any, event: TaskStartedEvent) -> None:
             span = self._telemetry.task_started(crew=source.agent.crew, task=source)
             self.execution_spans[source] = span

@@ -211,7 +212,7 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(TaskCompletedEvent)
-        def on_task_completed(source, event: TaskCompletedEvent):
+        def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
             # Handle telemetry
             span = self.execution_spans.get(source)
             if span:

@@ -229,7 +230,7 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(TaskFailedEvent)
-        def on_task_failed(source, event: TaskFailedEvent):
+        def on_task_failed(source: Any, event: TaskFailedEvent) -> None:
             span = self.execution_spans.get(source)
             if span:
                 if source.agent and source.agent.crew:

@@ -249,7 +250,9 @@ class EventListener(BaseEventListener):
         # ----------- AGENT EVENTS -----------

         @crewai_event_bus.on(AgentExecutionStartedEvent)
-        def on_agent_execution_started(source, event: AgentExecutionStartedEvent):
+        def on_agent_execution_started(
+            _: Any, event: AgentExecutionStartedEvent
+        ) -> None:
             self.formatter.create_agent_branch(
                 self.formatter.current_task_branch,
                 event.agent.role,

@@ -257,7 +260,9 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(AgentExecutionCompletedEvent)
-        def on_agent_execution_completed(source, event: AgentExecutionCompletedEvent):
+        def on_agent_execution_completed(
+            _: Any, event: AgentExecutionCompletedEvent
+        ) -> None:
             self.formatter.update_agent_status(
                 self.formatter.current_agent_branch,
                 event.agent.role,

@@ -268,8 +273,8 @@ class EventListener(BaseEventListener):

         @crewai_event_bus.on(LiteAgentExecutionStartedEvent)
         def on_lite_agent_execution_started(
-            source, event: LiteAgentExecutionStartedEvent
-        ):
+            _: Any, event: LiteAgentExecutionStartedEvent
+        ) -> None:
             """Handle LiteAgent execution started event."""
             self.formatter.handle_lite_agent_execution(
                 event.agent_info["role"], status="started", **event.agent_info

|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionCompletedEvent)
|
||||
def on_lite_agent_execution_completed(
|
||||
source, event: LiteAgentExecutionCompletedEvent
|
||||
):
|
||||
_: Any, event: LiteAgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
"""Handle LiteAgent execution completed event."""
|
||||
self.formatter.handle_lite_agent_execution(
|
||||
event.agent_info["role"], status="completed", **event.agent_info
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionErrorEvent)
|
||||
def on_lite_agent_execution_error(source, event: LiteAgentExecutionErrorEvent):
|
||||
def on_lite_agent_execution_error(
|
||||
_: Any, event: LiteAgentExecutionErrorEvent
|
||||
) -> None:
|
||||
"""Handle LiteAgent execution error event."""
|
||||
self.formatter.handle_lite_agent_execution(
|
||||
event.agent_info["role"],
|
||||
@@ -297,26 +304,28 @@ class EventListener(BaseEventListener):
|
||||
# ----------- FLOW EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(FlowCreatedEvent)
|
||||
def on_flow_created(source, event: FlowCreatedEvent):
|
||||
def on_flow_created(_: Any, event: FlowCreatedEvent) -> None:
|
||||
self._telemetry.flow_creation_span(event.flow_name)
|
||||
tree = self.formatter.create_flow_tree(event.flow_name, str(source.flow_id))
|
||||
self.formatter.current_flow_tree = tree
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def on_flow_started(source, event: FlowStartedEvent):
|
||||
def on_flow_started(source: Any, event: FlowStartedEvent) -> None:
|
||||
self._telemetry.flow_execution_span(
|
||||
event.flow_name, list(source._methods.keys())
|
||||
)
|
||||
tree = self.formatter.create_flow_tree(event.flow_name, str(source.flow_id))
|
||||
self.formatter.current_flow_tree = tree
|
||||
self.formatter.start_flow(event.flow_name, str(source.flow_id))
|
||||
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def on_flow_finished(source, event: FlowFinishedEvent):
|
||||
def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None:
|
||||
self.formatter.update_flow_status(
|
||||
self.formatter.current_flow_tree, event.flow_name, source.flow_id
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def on_method_execution_started(source, event: MethodExecutionStartedEvent):
|
||||
def on_method_execution_started(
|
||||
_: Any, event: MethodExecutionStartedEvent
|
||||
) -> None:
|
||||
method_branch = self.method_branches.get(event.method_name)
|
||||
updated_branch = self.formatter.update_method_status(
|
||||
method_branch,
|
||||
@@ -327,7 +336,9 @@ class EventListener(BaseEventListener):
|
||||
self.method_branches[event.method_name] = updated_branch
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def on_method_execution_finished(source, event: MethodExecutionFinishedEvent):
|
||||
def on_method_execution_finished(
|
||||
_: Any, event: MethodExecutionFinishedEvent
|
||||
) -> None:
|
||||
method_branch = self.method_branches.get(event.method_name)
|
||||
updated_branch = self.formatter.update_method_status(
|
||||
method_branch,
|
||||
@@ -338,7 +349,9 @@ class EventListener(BaseEventListener):
             self.method_branches[event.method_name] = updated_branch

         @crewai_event_bus.on(MethodExecutionFailedEvent)
-        def on_method_execution_failed(source, event: MethodExecutionFailedEvent):
+        def on_method_execution_failed(
+            _: Any, event: MethodExecutionFailedEvent
+        ) -> None:
             method_branch = self.method_branches.get(event.method_name)
             updated_branch = self.formatter.update_method_status(
                 method_branch,

@@ -351,7 +364,7 @@ class EventListener(BaseEventListener):
         # ----------- TOOL USAGE EVENTS -----------

         @crewai_event_bus.on(ToolUsageStartedEvent)
-        def on_tool_usage_started(source, event: ToolUsageStartedEvent):
+        def on_tool_usage_started(source: Any, event: ToolUsageStartedEvent) -> None:
             if isinstance(source, LLM):
                 self.formatter.handle_llm_tool_usage_started(
                     event.tool_name,

@@ -365,7 +378,7 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(ToolUsageFinishedEvent)
-        def on_tool_usage_finished(source, event: ToolUsageFinishedEvent):
+        def on_tool_usage_finished(source: Any, event: ToolUsageFinishedEvent) -> None:
             if isinstance(source, LLM):
                 self.formatter.handle_llm_tool_usage_finished(
                     event.tool_name,

@@ -378,7 +391,7 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(ToolUsageErrorEvent)
-        def on_tool_usage_error(source, event: ToolUsageErrorEvent):
+        def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
             if isinstance(source, LLM):
                 self.formatter.handle_llm_tool_usage_error(
                     event.tool_name,

@@ -395,7 +408,9 @@ class EventListener(BaseEventListener):
         # ----------- LLM EVENTS -----------

         @crewai_event_bus.on(LLMCallStartedEvent)
-        def on_llm_call_started(source, event: LLMCallStartedEvent):
+        def on_llm_call_started(_: Any, event: LLMCallStartedEvent) -> None:
+            self.text_stream = StringIO()
+            self.next_chunk = 0
             # Capture the returned tool branch and update the current_tool_branch reference
             thinking_branch = self.formatter.handle_llm_call_started(
                 self.formatter.current_agent_branch,

@@ -406,7 +421,8 @@ class EventListener(BaseEventListener):
             self.formatter.current_tool_branch = thinking_branch

         @crewai_event_bus.on(LLMCallCompletedEvent)
-        def on_llm_call_completed(source, event: LLMCallCompletedEvent):
+        def on_llm_call_completed(_: Any, event: LLMCallCompletedEvent) -> None:
+            self.formatter.handle_llm_stream_completed()
             self.formatter.handle_llm_call_completed(
                 self.formatter.current_tool_branch,
                 self.formatter.current_agent_branch,

@@ -414,7 +430,8 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(LLMCallFailedEvent)
-        def on_llm_call_failed(source, event: LLMCallFailedEvent):
+        def on_llm_call_failed(_: Any, event: LLMCallFailedEvent) -> None:
+            self.formatter.handle_llm_stream_completed()
             self.formatter.handle_llm_call_failed(
                 self.formatter.current_tool_branch,
                 event.error,

@@ -422,16 +439,24 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(LLMStreamChunkEvent)
-        def on_llm_stream_chunk(source, event: LLMStreamChunkEvent):
+        def on_llm_stream_chunk(_: Any, event: LLMStreamChunkEvent) -> None:
             self.text_stream.write(event.chunk)
             self.text_stream.seek(self.next_chunk)
             self.text_stream.read()
             self.next_chunk = self.text_stream.tell()

+            accumulated_text = self.text_stream.getvalue()
+            self.formatter.handle_llm_stream_chunk(
+                event.chunk,
+                accumulated_text,
+                self.formatter.current_crew_tree,
+                event.call_type,
+            )
+
         # ----------- LLM GUARDRAIL EVENTS -----------

         @crewai_event_bus.on(LLMGuardrailStartedEvent)
-        def on_llm_guardrail_started(source, event: LLMGuardrailStartedEvent):
+        def on_llm_guardrail_started(_: Any, event: LLMGuardrailStartedEvent) -> None:
             guardrail_str = str(event.guardrail)
             guardrail_name = (
                 guardrail_str[:50] + "..." if len(guardrail_str) > 50 else guardrail_str

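The handler keeps one StringIO per LLM call and tracks a read cursor, so it can recover both the newly arrived delta and the full accumulated text; a self-contained sketch of that bookkeeping:

    from io import StringIO

    text_stream = StringIO()
    next_chunk = 0
    for piece in ["Hel", "lo, ", "world"]:
        text_stream.write(piece)        # append the new chunk at the end
        text_stream.seek(next_chunk)
        delta = text_stream.read()      # text added since the last event
        next_chunk = text_stream.tell()
        accumulated = text_stream.getvalue()
    print(accumulated)                  # Hello, world
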
@@ -440,13 +465,15 @@ class EventListener(BaseEventListener):
             self.formatter.handle_guardrail_started(guardrail_name, event.retry_count)

         @crewai_event_bus.on(LLMGuardrailCompletedEvent)
-        def on_llm_guardrail_completed(source, event: LLMGuardrailCompletedEvent):
+        def on_llm_guardrail_completed(
+            _: Any, event: LLMGuardrailCompletedEvent
+        ) -> None:
             self.formatter.handle_guardrail_completed(
                 event.success, event.error, event.retry_count
             )

         @crewai_event_bus.on(CrewTestStartedEvent)
-        def on_crew_test_started(source, event: CrewTestStartedEvent):
+        def on_crew_test_started(source: Any, event: CrewTestStartedEvent) -> None:
             cloned_crew = source.copy()
             self._telemetry.test_execution_span(
                 cloned_crew,

@@ -460,20 +487,20 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(CrewTestCompletedEvent)
-        def on_crew_test_completed(source, event: CrewTestCompletedEvent):
+        def on_crew_test_completed(_: Any, event: CrewTestCompletedEvent) -> None:
             self.formatter.handle_crew_test_completed(
                 self.formatter.current_flow_tree,
                 event.crew_name or "Crew",
             )

         @crewai_event_bus.on(CrewTestFailedEvent)
-        def on_crew_test_failed(source, event: CrewTestFailedEvent):
+        def on_crew_test_failed(_: Any, event: CrewTestFailedEvent) -> None:
             self.formatter.handle_crew_test_failed(event.crew_name or "Crew")

         @crewai_event_bus.on(KnowledgeRetrievalStartedEvent)
         def on_knowledge_retrieval_started(
-            source, event: KnowledgeRetrievalStartedEvent
-        ):
+            _: Any, event: KnowledgeRetrievalStartedEvent
+        ) -> None:
             if self.knowledge_retrieval_in_progress:
                 return

@@ -486,8 +513,8 @@ class EventListener(BaseEventListener):

         @crewai_event_bus.on(KnowledgeRetrievalCompletedEvent)
         def on_knowledge_retrieval_completed(
-            source, event: KnowledgeRetrievalCompletedEvent
-        ):
+            _: Any, event: KnowledgeRetrievalCompletedEvent
+        ) -> None:
             if not self.knowledge_retrieval_in_progress:
                 return

@@ -499,11 +526,13 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(KnowledgeQueryStartedEvent)
-        def on_knowledge_query_started(source, event: KnowledgeQueryStartedEvent):
+        def on_knowledge_query_started(
+            _: Any, event: KnowledgeQueryStartedEvent
+        ) -> None:
             pass

         @crewai_event_bus.on(KnowledgeQueryFailedEvent)
-        def on_knowledge_query_failed(source, event: KnowledgeQueryFailedEvent):
+        def on_knowledge_query_failed(_: Any, event: KnowledgeQueryFailedEvent) -> None:
             self.formatter.handle_knowledge_query_failed(
                 self.formatter.current_agent_branch,
                 event.error,

@@ -511,13 +540,15 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(KnowledgeQueryCompletedEvent)
-        def on_knowledge_query_completed(source, event: KnowledgeQueryCompletedEvent):
+        def on_knowledge_query_completed(
+            _: Any, event: KnowledgeQueryCompletedEvent
+        ) -> None:
             pass

         @crewai_event_bus.on(KnowledgeSearchQueryFailedEvent)
         def on_knowledge_search_query_failed(
-            source, event: KnowledgeSearchQueryFailedEvent
-        ):
+            _: Any, event: KnowledgeSearchQueryFailedEvent
+        ) -> None:
             self.formatter.handle_knowledge_search_query_failed(
                 self.formatter.current_agent_branch,
                 event.error,

@@ -527,7 +558,9 @@ class EventListener(BaseEventListener):
         # ----------- REASONING EVENTS -----------

         @crewai_event_bus.on(AgentReasoningStartedEvent)
-        def on_agent_reasoning_started(source, event: AgentReasoningStartedEvent):
+        def on_agent_reasoning_started(
+            _: Any, event: AgentReasoningStartedEvent
+        ) -> None:
             self.formatter.handle_reasoning_started(
                 self.formatter.current_agent_branch,
                 event.attempt,

@@ -535,7 +568,9 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(AgentReasoningCompletedEvent)
-        def on_agent_reasoning_completed(source, event: AgentReasoningCompletedEvent):
+        def on_agent_reasoning_completed(
+            _: Any, event: AgentReasoningCompletedEvent
+        ) -> None:
             self.formatter.handle_reasoning_completed(
                 event.plan,
                 event.ready,

@@ -543,7 +578,7 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(AgentReasoningFailedEvent)
-        def on_agent_reasoning_failed(source, event: AgentReasoningFailedEvent):
+        def on_agent_reasoning_failed(_: Any, event: AgentReasoningFailedEvent) -> None:
             self.formatter.handle_reasoning_failed(
                 event.error,
                 self.formatter.current_crew_tree,

@@ -552,7 +587,7 @@ class EventListener(BaseEventListener):
         # ----------- AGENT LOGGING EVENTS -----------

         @crewai_event_bus.on(AgentLogsStartedEvent)
-        def on_agent_logs_started(source, event: AgentLogsStartedEvent):
+        def on_agent_logs_started(_: Any, event: AgentLogsStartedEvent) -> None:
             self.formatter.handle_agent_logs_started(
                 event.agent_role,
                 event.task_description,

@@ -560,7 +595,7 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(AgentLogsExecutionEvent)
-        def on_agent_logs_execution(source, event: AgentLogsExecutionEvent):
+        def on_agent_logs_execution(_: Any, event: AgentLogsExecutionEvent) -> None:
             self.formatter.handle_agent_logs_execution(
                 event.agent_role,
                 event.formatted_answer,

@@ -568,7 +603,7 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(A2ADelegationStartedEvent)
-        def on_a2a_delegation_started(source, event: A2ADelegationStartedEvent):
+        def on_a2a_delegation_started(_: Any, event: A2ADelegationStartedEvent) -> None:
             self.formatter.handle_a2a_delegation_started(
                 event.endpoint,
                 event.task_description,

@@ -578,7 +613,9 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(A2ADelegationCompletedEvent)
-        def on_a2a_delegation_completed(source, event: A2ADelegationCompletedEvent):
+        def on_a2a_delegation_completed(
+            _: Any, event: A2ADelegationCompletedEvent
+        ) -> None:
             self.formatter.handle_a2a_delegation_completed(
                 event.status,
                 event.result,

@@ -587,7 +624,9 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(A2AConversationStartedEvent)
-        def on_a2a_conversation_started(source, event: A2AConversationStartedEvent):
+        def on_a2a_conversation_started(
+            _: Any, event: A2AConversationStartedEvent
+        ) -> None:
             # Store A2A agent name for display in conversation tree
             if event.a2a_agent_name:
                 self.formatter._current_a2a_agent_name = event.a2a_agent_name

@@ -598,7 +637,7 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(A2AMessageSentEvent)
-        def on_a2a_message_sent(source, event: A2AMessageSentEvent):
+        def on_a2a_message_sent(_: Any, event: A2AMessageSentEvent) -> None:
             self.formatter.handle_a2a_message_sent(
                 event.message,
                 event.turn_number,

@@ -606,7 +645,7 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(A2AResponseReceivedEvent)
-        def on_a2a_response_received(source, event: A2AResponseReceivedEvent):
+        def on_a2a_response_received(_: Any, event: A2AResponseReceivedEvent) -> None:
             self.formatter.handle_a2a_response_received(
                 event.response,
                 event.turn_number,

@@ -615,7 +654,9 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(A2AConversationCompletedEvent)
-        def on_a2a_conversation_completed(source, event: A2AConversationCompletedEvent):
+        def on_a2a_conversation_completed(
+            _: Any, event: A2AConversationCompletedEvent
+        ) -> None:
             self.formatter.handle_a2a_conversation_completed(
                 event.status,
                 event.final_result,

@@ -626,7 +667,7 @@ class EventListener(BaseEventListener):
         # ----------- MCP EVENTS -----------

         @crewai_event_bus.on(MCPConnectionStartedEvent)
-        def on_mcp_connection_started(source, event: MCPConnectionStartedEvent):
+        def on_mcp_connection_started(_: Any, event: MCPConnectionStartedEvent) -> None:
             self.formatter.handle_mcp_connection_started(
                 event.server_name,
                 event.server_url,

@@ -636,7 +677,9 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(MCPConnectionCompletedEvent)
-        def on_mcp_connection_completed(source, event: MCPConnectionCompletedEvent):
+        def on_mcp_connection_completed(
+            _: Any, event: MCPConnectionCompletedEvent
+        ) -> None:
             self.formatter.handle_mcp_connection_completed(
                 event.server_name,
                 event.server_url,

@@ -646,7 +689,7 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(MCPConnectionFailedEvent)
-        def on_mcp_connection_failed(source, event: MCPConnectionFailedEvent):
+        def on_mcp_connection_failed(_: Any, event: MCPConnectionFailedEvent) -> None:
             self.formatter.handle_mcp_connection_failed(
                 event.server_name,
                 event.server_url,

@@ -656,7 +699,9 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(MCPToolExecutionStartedEvent)
-        def on_mcp_tool_execution_started(source, event: MCPToolExecutionStartedEvent):
+        def on_mcp_tool_execution_started(
+            _: Any, event: MCPToolExecutionStartedEvent
+        ) -> None:
             self.formatter.handle_mcp_tool_execution_started(
                 event.server_name,
                 event.tool_name,

@@ -665,8 +710,8 @@ class EventListener(BaseEventListener):

         @crewai_event_bus.on(MCPToolExecutionCompletedEvent)
         def on_mcp_tool_execution_completed(
-            source, event: MCPToolExecutionCompletedEvent
-        ):
+            _: Any, event: MCPToolExecutionCompletedEvent
+        ) -> None:
             self.formatter.handle_mcp_tool_execution_completed(
                 event.server_name,
                 event.tool_name,

@@ -676,7 +721,9 @@ class EventListener(BaseEventListener):
             )

         @crewai_event_bus.on(MCPToolExecutionFailedEvent)
-        def on_mcp_tool_execution_failed(source, event: MCPToolExecutionFailedEvent):
+        def on_mcp_tool_execution_failed(
+            _: Any, event: MCPToolExecutionFailedEvent
+        ) -> None:
             self.formatter.handle_mcp_tool_execution_failed(
                 event.server_name,
                 event.tool_name,

@@ -64,6 +64,7 @@ class FlowFinishedEvent(FlowEvent):
     flow_name: str
     result: Any | None = None
     type: str = "flow_finished"
+    state: dict[str, Any] | BaseModel


 class FlowPlotEvent(FlowEvent):

@@ -10,7 +10,7 @@ class LLMEventBase(BaseEvent):
     from_task: Any | None = None
     from_agent: Any | None = None

-    def __init__(self, **data):
+    def __init__(self, **data: Any) -> None:
         if data.get("from_task"):
             task = data["from_task"]
             data["task_id"] = str(task.id)

@@ -84,3 +84,4 @@ class LLMStreamChunkEvent(LLMEventBase):
     type: str = "llm_stream_chunk"
     chunk: str
     tool_call: ToolCall | None = None
+    call_type: LLMCallType | None = None

@@ -21,7 +21,7 @@ class ConsoleFormatter:
     current_reasoning_branch: Tree | None = None
     _live_paused: bool = False
     current_llm_tool_tree: Tree | None = None
-    current_a2a_conversation_branch: Tree | None = None
+    current_a2a_conversation_branch: Tree | str | None = None
     current_a2a_turn_count: int = 0
     _pending_a2a_message: str | None = None
     _pending_a2a_agent_role: str | None = None

@@ -39,6 +39,10 @@ class ConsoleFormatter:
         # Once any non-Tree renderable is printed we stop the Live session so the
         # final Tree persists on the terminal.
         self._live: Live | None = None
+        self._streaming_live: Live | None = None
+        self._is_streaming: bool = False
+        self._just_streamed_final_answer: bool = False
+        self._last_stream_call_type: Any = None

     def create_panel(self, content: Text, title: str, style: str = "blue") -> Panel:
         """Create a standardized panel with consistent styling."""

@@ -146,6 +150,9 @@ To enable tracing, do any one of these:
         if len(args) == 1 and isinstance(args[0], Tree):
             tree = args[0]

+            if self._is_streaming:
+                return
+
             if not self._live:
                 # Start a new Live session for the first tree
                 self._live = Live(tree, console=self.console, refresh_per_second=4)

@@ -554,7 +561,7 @@ To enable tracing, do any one of these:
         self,
         tool_name: str,
         tool_args: dict[str, Any] | str,
-    ) -> None:
+    ) -> Tree:
         # Create status content for the tool usage
         content = self.create_status_content(
             "Tool Usage Started", tool_name, Status="In Progress", tool_args=tool_args

@@ -762,11 +769,14 @@ To enable tracing, do any one of these:
         thinking_branch_to_remove = None
         removed = False

-        # Method 1: Use the provided tool_branch if it's a thinking node
-        if tool_branch is not None and "Thinking" in str(tool_branch.label):
+        # Method 1: Use the provided tool_branch if it's a thinking/streaming node
+        if tool_branch is not None and (
+            "Thinking" in str(tool_branch.label)
+            or "Streaming" in str(tool_branch.label)
+        ):
             thinking_branch_to_remove = tool_branch

-        # Method 2: Fallback - search for any thinking node if tool_branch is None or not thinking
+        # Method 2: Fallback - search for any thinking/streaming node if tool_branch is None or not found
         if thinking_branch_to_remove is None:
             parents = [
                 self.current_lite_agent_branch,

@@ -777,7 +787,8 @@ To enable tracing, do any one of these:
             for parent in parents:
                 if isinstance(parent, Tree):
                     for child in parent.children:
-                        if "Thinking" in str(child.label):
+                        label_str = str(child.label)
+                        if "Thinking" in label_str or "Streaming" in label_str:
                             thinking_branch_to_remove = child
                             break
         if thinking_branch_to_remove:

@@ -821,11 +832,13 @@ To enable tracing, do any one of these:
         # Find the thinking branch to update (similar to completion logic)
         thinking_branch_to_update = None

         # Method 1: Use the provided tool_branch if it's a thinking node
-        if tool_branch is not None and "Thinking" in str(tool_branch.label):
+        if tool_branch is not None and (
+            "Thinking" in str(tool_branch.label)
+            or "Streaming" in str(tool_branch.label)
+        ):
             thinking_branch_to_update = tool_branch

-        # Method 2: Fallback - search for any thinking node if tool_branch is None or not thinking
+        # Method 2: Fallback - search for any thinking/streaming node if tool_branch is None or not found
         if thinking_branch_to_update is None:
             parents = [
                 self.current_lite_agent_branch,

@@ -836,7 +849,8 @@ To enable tracing, do any one of these:
             for parent in parents:
                 if isinstance(parent, Tree):
                     for child in parent.children:
-                        if "Thinking" in str(child.label):
+                        label_str = str(child.label)
+                        if "Thinking" in label_str or "Streaming" in label_str:
                             thinking_branch_to_update = child
                             break
         if thinking_branch_to_update:

@@ -860,6 +874,83 @@ To enable tracing, do any one of these:
|
||||
|
||||
self.print_panel(error_content, "LLM Error", "red")
|
||||
|
||||
def handle_llm_stream_chunk(
|
||||
self,
|
||||
chunk: str,
|
||||
accumulated_text: str,
|
||||
crew_tree: Tree | None,
|
||||
call_type: Any = None,
|
||||
) -> None:
|
||||
"""Handle LLM stream chunk event - display streaming text in a panel.
|
||||
|
||||
Args:
|
||||
chunk: The new chunk of text received.
|
||||
accumulated_text: All text accumulated so far.
|
||||
crew_tree: The current crew tree for rendering.
|
||||
call_type: The type of LLM call (LLM_CALL or TOOL_CALL).
|
||||
"""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
self._is_streaming = True
|
||||
self._last_stream_call_type = call_type
|
||||
|
||||
if self._live:
|
||||
self._live.stop()
|
||||
self._live = None
|
||||
|
||||
display_text = accumulated_text
|
||||
max_lines = 20
|
||||
lines = display_text.split("\n")
|
||||
if len(lines) > max_lines:
|
||||
display_text = "\n".join(lines[-max_lines:])
|
||||
display_text = "...\n" + display_text
|
||||
|
||||
content = Text()
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
|
||||
if call_type == LLMCallType.TOOL_CALL:
|
||||
content.append(display_text, style="yellow")
|
||||
title = "🔧 Tool Arguments"
|
||||
border_style = "yellow"
|
||||
else:
|
||||
content.append(display_text, style="bright_green")
|
||||
title = "✅ Agent Final Answer"
|
||||
border_style = "green"
|
||||
|
||||
streaming_panel = Panel(
|
||||
content,
|
||||
title=title,
|
||||
border_style=border_style,
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
||||
if not self._streaming_live:
|
||||
self._streaming_live = Live(
|
||||
streaming_panel, console=self.console, refresh_per_second=10
|
||||
)
|
||||
self._streaming_live.start()
|
||||
else:
|
||||
self._streaming_live.update(streaming_panel, refresh=True)
|
||||
|
||||
def handle_llm_stream_completed(self) -> None:
|
||||
"""Handle completion of LLM streaming - stop the streaming live display."""
|
||||
self._is_streaming = False
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
|
||||
if self._last_stream_call_type == LLMCallType.LLM_CALL:
|
||||
self._just_streamed_final_answer = True
|
||||
else:
|
||||
self._just_streamed_final_answer = False
|
||||
|
||||
self._last_stream_call_type = None
|
||||
|
||||
if self._streaming_live:
|
||||
self._streaming_live.stop()
|
||||
self._streaming_live = None
|
||||
|
||||
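A minimal standalone sketch of the tail-truncation used by handle_llm_stream_chunk above (stdlib only): keeping just the last 20 lines, prefixed with an ellipsis, is what bounds the height of the Rich Live panel during long streams.

def truncate_tail(text: str, max_lines: int = 20) -> str:
    # Keep only the most recent lines so the streaming panel stays a fixed height.
    lines = text.split("\n")
    if len(lines) <= max_lines:
        return text
    return "...\n" + "\n".join(lines[-max_lines:])

assert truncate_tail("a\nb") == "a\nb"
assert truncate_tail("\n".join(str(i) for i in range(30))).startswith("...\n10\n11")
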
def handle_crew_test_started(
self, crew_name: str, source_id: str, n_iterations: int
) -> Tree | None:
@@ -1528,6 +1619,10 @@ To enable tracing, do any one of these:
self.print()

elif isinstance(formatted_answer, AgentFinish):
if self._just_streamed_final_answer:
self._just_streamed_final_answer = False
return

is_a2a_delegation = False
try:
output_data = json.loads(formatted_answer.output)
@@ -1866,7 +1961,7 @@ To enable tracing, do any one of these:
agent_id: str,
is_multiturn: bool = False,
turn_number: int = 1,
) -> None:
) -> Tree | None:
"""Handle A2A delegation started event.

Args:
@@ -1979,7 +2074,7 @@ To enable tracing, do any one of these:
if status == "input_required" and error:
pass
elif status == "completed":
if has_tree:
if has_tree and isinstance(self.current_a2a_conversation_branch, Tree):
final_turn = self.current_a2a_conversation_branch.add("")
self.update_tree_label(
final_turn,
@@ -1995,7 +2090,7 @@ To enable tracing, do any one of these:
self.current_a2a_conversation_branch = None
self.current_a2a_turn_count = 0
elif status == "failed":
if has_tree:
if has_tree and isinstance(self.current_a2a_conversation_branch, Tree):
error_turn = self.current_a2a_conversation_branch.add("")
error_msg = (
error[:150] + "..." if error and len(error) > 150 else error

@@ -70,7 +70,16 @@ from crewai.flow.utils import (
|
||||
is_simple_flow_condition,
|
||||
)
|
||||
from crewai.flow.visualization import build_flow_structure, render_interactive
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
from crewai.utilities.printer import Printer, PrinterColor
|
||||
from crewai.utilities.streaming import (
|
||||
TaskInfo,
|
||||
create_async_chunk_generator,
|
||||
create_chunk_generator,
|
||||
create_streaming_state,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -456,6 +465,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
initial_state: type[T] | T | None = None
|
||||
name: str | None = None
|
||||
tracing: bool | None = None
|
||||
stream: bool = False
|
||||
|
||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
|
||||
class _FlowGeneric(cls): # type: ignore
|
||||
@@ -822,20 +832,56 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if hasattr(self._state, key):
|
||||
object.__setattr__(self._state, key, value)
|
||||
|
||||
def kickoff(self, inputs: dict[str, Any] | None = None) -> Any:
|
||||
def kickoff(
|
||||
self, inputs: dict[str, Any] | None = None
|
||||
) -> Any | FlowStreamingOutput:
|
||||
"""
|
||||
Start the flow execution in a synchronous context.
|
||||
|
||||
This method wraps kickoff_async so that all state initialization and event
|
||||
emission is handled in the asynchronous method.
|
||||
"""
|
||||
if self.stream:
|
||||
result_holder: list[Any] = []
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(
|
||||
current_task_info, result_holder, use_async=False
|
||||
)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
def run_flow() -> None:
|
||||
try:
|
||||
self.stream = False
|
||||
result = self.kickoff(inputs=inputs)
|
||||
result_holder.append(result)
|
||||
except Exception as e:
|
||||
signal_error(state, e)
|
||||
finally:
|
||||
self.stream = True
|
||||
signal_end(state)
|
||||
|
||||
streaming_output = FlowStreamingOutput(
|
||||
sync_iterator=create_chunk_generator(state, run_flow, output_holder)
|
||||
)
|
||||
output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
async def _run_flow() -> Any:
|
||||
return await self.kickoff_async(inputs)
|
||||
|
||||
return asyncio.run(_run_flow())
|
||||
|
||||
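A hedged usage sketch of the synchronous branch above: with stream=True, kickoff() returns a FlowStreamingOutput to iterate instead of the final result, while the flow runs in the background. The flow class is hypothetical, and the iteration protocol is assumed from the sync_iterator wiring shown; a flow with no LLM activity may yield no chunks at all.

from crewai.flow.flow import Flow, start

class GreeterFlow(Flow):  # hypothetical flow for illustration
    @start()
    def greet(self) -> str:
        return "hello"

flow = GreeterFlow()
flow.stream = True
streaming = flow.kickoff()  # a FlowStreamingOutput, not the final value
for chunk in streaming:     # assuming the wrapper is iterable, per the diff above
    print(chunk, end="", flush=True)
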
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any:
async def kickoff_async(
self, inputs: dict[str, Any] | None = None
) -> Any | FlowStreamingOutput:
"""
Start the flow execution asynchronously.

@@ -850,6 +896,41 @@ class Flow(Generic[T], metaclass=FlowMeta):
Returns:
The final output from the flow, which is the result of the last executed method.
"""
if self.stream:
result_holder: list[Any] = []
current_task_info: TaskInfo = {
"index": 0,
"name": "",
"id": "",
"agent_role": "",
"agent_id": "",
}

state = create_streaming_state(
current_task_info, result_holder, use_async=True
)
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []

async def run_flow() -> None:
try:
self.stream = False
result = await self.kickoff_async(inputs=inputs)
result_holder.append(result)
except Exception as e:
signal_error(state, e, is_async=True)
finally:
self.stream = True
signal_end(state, is_async=True)

streaming_output = FlowStreamingOutput(
async_iterator=create_async_chunk_generator(
state, run_flow, output_holder
)
)
output_holder.append(streaming_output)

return streaming_output

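The asynchronous branch mirrors it; a sketch reusing the hypothetical GreeterFlow above, assuming FlowStreamingOutput supports async iteration when built with async_iterator:

import asyncio

async def consume() -> None:
    flow = GreeterFlow()
    flow.stream = True
    streaming = await flow.kickoff_async()  # returns the streaming wrapper immediately
    async for chunk in streaming:
        print(chunk, end="", flush=True)

asyncio.run(consume())
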
ctx = baggage.set_baggage("flow_inputs", inputs or {})
flow_token = attach(ctx)

@@ -927,6 +1008,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
type="flow_finished",
flow_name=self.name or self.__class__.__name__,
result=final_output,
state=self._copy_and_serialize_state(),
),
)
if future:
@@ -1028,6 +1110,7 @@
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
kwargs or {}
)

future = crewai_event_bus.emit(
self,
MethodExecutionStartedEvent(
@@ -1035,7 +1118,7 @@
method_name=method_name,
flow_name=self.name or self.__class__.__name__,
params=dumped_params,
state=self._copy_state(),
state=self._copy_and_serialize_state(),
),
)
if future:
@@ -1053,13 +1136,14 @@
)

self._completed_methods.add(method_name)

future = crewai_event_bus.emit(
self,
MethodExecutionFinishedEvent(
type="method_execution_finished",
method_name=method_name,
flow_name=self.name or self.__class__.__name__,
state=self._copy_state(),
state=self._copy_and_serialize_state(),
result=result,
),
)
@@ -1081,6 +1165,16 @@
self._event_futures.append(future)
raise e

def _copy_and_serialize_state(self) -> dict[str, Any]:
state_copy = self._copy_state()
if isinstance(state_copy, BaseModel):
try:
return state_copy.model_dump(mode="json")
except Exception:
return state_copy.model_dump()
else:
return state_copy

async def _execute_listeners(
self, trigger_method: FlowMethodName, result: Any
) -> None:

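What _copy_and_serialize_state guards against, sketched with plain Pydantic: event payloads need to be JSON-safe, and mode="json" converts values like datetimes to primitives, with a plain model_dump() as the fallback when JSON-mode serialization raises.

from datetime import datetime, timezone
from pydantic import BaseModel

class State(BaseModel):
    started_at: datetime

state = State(started_at=datetime(2025, 1, 1, tzinfo=timezone.utc))
print(state.model_dump())             # {'started_at': datetime.datetime(2025, 1, 1, ...)}
print(state.model_dump(mode="json"))  # {'started_at': '...'} with an ISO 8601 string
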
@@ -17,6 +17,7 @@ from __future__ import annotations

import ast
from collections import defaultdict, deque
from enum import Enum
import inspect
import textwrap
from typing import TYPE_CHECKING, Any
@@ -40,11 +41,123 @@ if TYPE_CHECKING:
_printer = Printer()


def _extract_string_literals_from_type_annotation(
node: ast.expr,
function_globals: dict[str, Any] | None = None,
) -> list[str]:
"""Extract string literals from a type annotation AST node.

Handles:
- Literal["a", "b", "c"]
- "a" | "b" | "c" (union of string literals)
- Just "a" (single string constant annotation)
- Enum types with string values (e.g., class MyEnum(str, Enum))

Args:
node: The AST node representing a type annotation.
function_globals: The globals dict from the function, used to resolve Enum types.

Returns:
List of string literals found in the annotation.
"""

strings: list[str] = []

if isinstance(node, ast.Constant) and isinstance(node.value, str):
strings.append(node.value)

elif isinstance(node, ast.Name) and function_globals:
enum_class = function_globals.get(node.id)
if (
enum_class is not None
and isinstance(enum_class, type)
and issubclass(enum_class, Enum)
):
strings.extend(
member.value for member in enum_class if isinstance(member.value, str)
)

elif isinstance(node, ast.Attribute) and function_globals:
try:
if isinstance(node.value, ast.Name):
module = function_globals.get(node.value.id)
if module is not None:
enum_class = getattr(module, node.attr, None)
if (
enum_class is not None
and isinstance(enum_class, type)
and issubclass(enum_class, Enum)
):
strings.extend(
member.value
for member in enum_class
if isinstance(member.value, str)
)
except (AttributeError, TypeError):
pass

elif isinstance(node, ast.Subscript):
is_literal = False
if isinstance(node.value, ast.Name) and node.value.id == "Literal":
is_literal = True
elif isinstance(node.value, ast.Attribute) and node.value.attr == "Literal":
is_literal = True

if is_literal:
if isinstance(node.slice, ast.Tuple):
strings.extend(
elt.value
for elt in node.slice.elts
if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
)
elif isinstance(node.slice, ast.Constant) and isinstance(
node.slice.value, str
):
strings.append(node.slice.value)

elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
strings.extend(
_extract_string_literals_from_type_annotation(node.left, function_globals)
)
strings.extend(
_extract_string_literals_from_type_annotation(node.right, function_globals)
)

return strings


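A quick self-contained check of the extractor above: parse an annotated function and pull the string literals out of its Literal[...] return annotation. The expected output is an assumption based on the branches shown in this hunk.

import ast

src = 'def route() -> Literal["approved", "rejected"]: ...'
fn = ast.parse(src).body[0]
assert isinstance(fn, ast.FunctionDef) and fn.returns is not None
print(_extract_string_literals_from_type_annotation(fn.returns))
# expected: ['approved', 'rejected']
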
def _unwrap_function(function: Any) -> Any:
"""Unwrap a function to get the original function with correct globals.

Flow methods are wrapped by decorators like @router, @listen, etc.
This function unwraps them to get the original function which has
the correct __globals__ for resolving type annotations like Enums.

Args:
function: The potentially wrapped function.

Returns:
The unwrapped original function.
"""
if hasattr(function, "__func__"):
function = function.__func__

if hasattr(function, "__wrapped__"):
wrapped = function.__wrapped__
if hasattr(wrapped, "unwrap"):
return wrapped.unwrap()
return wrapped

return function


def get_possible_return_constants(function: Any) -> list[str] | None:
"""Extract possible string return values from a function using AST parsing.

This function analyzes the source code of a router method to identify
all possible string values it might return. It handles:
- Return type annotations: -> Literal["a", "b"] or -> "a" | "b" | "c"
- Enum type annotations: -> MyEnum (extracts string values from members)
- Direct string literals: return "value"
- Variable assignments: x = "value"; return x
- Dictionary lookups: d = {"k": "v"}; return d[key]
@@ -57,6 +170,8 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
Returns:
List of possible string return values, or None if analysis fails.
"""
unwrapped = _unwrap_function(function)

try:
source = inspect.getsource(function)
except OSError:
@@ -97,6 +212,17 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
return None

return_values: set[str] = set()

function_globals = getattr(unwrapped, "__globals__", None)

for node in ast.walk(code_ast):
if isinstance(node, ast.FunctionDef):
if node.returns:
annotation_values = _extract_string_literals_from_type_annotation(
node.returns, function_globals
)
return_values.update(annotation_values)
break  # Only process the first function definition
dict_definitions: dict[str, list[str]] = {}
variable_values: dict[str, list[str]] = {}
state_attribute_values: dict[str, list[str]] = {}

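As a usage sketch: annotating a router's return type is what makes its paths statically discoverable by this analysis. Assuming a module-level router like the one below (so Route is resolvable via __globals__), get_possible_return_constants should report both enum values.

from enum import Enum

class Route(str, Enum):
    PAID = "paid"
    FREE = "free"

def pick_plan(user_is_paying: bool) -> Route:
    return Route.PAID if user_is_paying else Route.FREE

print(get_possible_return_constants(pick_plan))
# expected: ["paid", "free"] (ordering may vary; values are collected into a set)
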
@@ -3,13 +3,13 @@

from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
import inspect
import logging
from typing import TYPE_CHECKING, Any

from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.flow_wrappers import FlowCondition
from crewai.flow.types import FlowMethodName, FlowRouteName
from crewai.flow.types import FlowMethodName
from crewai.flow.utils import (
is_flow_condition_dict,
is_simple_flow_condition,
@@ -18,6 +18,9 @@ from crewai.flow.visualization.schema import extract_method_signature
from crewai.flow.visualization.types import FlowStructure, NodeMetadata, StructureEdge


logger = logging.getLogger(__name__)


if TYPE_CHECKING:
from crewai.flow.flow import Flow

@@ -346,34 +349,43 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
if trigger_method in nodes
)

all_string_triggers: set[str] = set()
for condition_data in flow._listeners.values():
if is_simple_flow_condition(condition_data):
_, methods = condition_data
for m in methods:
if str(m) not in nodes: # It's a string trigger, not a method name
all_string_triggers.add(str(m))
elif is_flow_condition_dict(condition_data):
for trigger in _extract_direct_or_triggers(condition_data):
if trigger not in nodes:
all_string_triggers.add(trigger)

all_router_outputs: set[str] = set()
for router_method_name in router_methods:
if router_method_name not in flow._router_paths:
flow._router_paths[FlowMethodName(router_method_name)] = []

inferred_paths: Iterable[FlowMethodName | FlowRouteName] = set(
flow._router_paths.get(FlowMethodName(router_method_name), [])
)
current_paths = flow._router_paths.get(FlowMethodName(router_method_name), [])
if current_paths and router_method_name in nodes:
nodes[router_method_name]["router_paths"] = [str(p) for p in current_paths]
all_router_outputs.update(str(p) for p in current_paths)

for condition_data in flow._listeners.values():
trigger_strings: list[str] = []

if is_simple_flow_condition(condition_data):
_, methods = condition_data
trigger_strings = [str(m) for m in methods]
elif is_flow_condition_dict(condition_data):
trigger_strings = _extract_direct_or_triggers(condition_data)

for trigger_str in trigger_strings:
if trigger_str not in nodes:
# This is likely a router path output
inferred_paths.add(trigger_str) # type: ignore[attr-defined]

if inferred_paths:
flow._router_paths[FlowMethodName(router_method_name)] = list(
inferred_paths # type: ignore[arg-type]
if not current_paths:
logger.warning(
f"Could not determine return paths for router '{router_method_name}'. "
f"Add a return type annotation like "
f"'-> Literal[\"path1\", \"path2\"]' or '-> YourEnum' "
f"to enable proper flow visualization."
)
if router_method_name in nodes:
nodes[router_method_name]["router_paths"] = list(inferred_paths)

orphaned_triggers = all_string_triggers - all_router_outputs
if orphaned_triggers:
logger.error(
f"Found listeners waiting for triggers {orphaned_triggers} "
f"but no router outputs these values explicitly. "
f"If your router returns a non-static value, check that your router has proper return type annotations."
)

for router_method_name in router_methods:
if router_method_name not in flow._router_paths:
@@ -383,6 +395,9 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:

for path in router_paths:
for listener_name, condition_data in flow._listeners.items():
if listener_name == router_method_name:
continue

trigger_strings_from_cond: list[str] = []

if is_simple_flow_condition(condition_data):

@@ -179,6 +179,7 @@ LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
"o3-mini": 200000,
"o4-mini": 200000,
# gemini
"gemini-3-pro-preview": 1048576,
"gemini-2.0-flash": 1048576,
"gemini-2.0-flash-thinking-exp-01-21": 32768,
"gemini-2.0-flash-lite-001": 1048576,
@@ -385,9 +386,10 @@ class LLM(BaseLLM):
if native_class and not is_litellm and provider in SUPPORTED_NATIVE_PROVIDERS:
try:
# Remove 'provider' from kwargs if it exists to avoid duplicate keyword argument
kwargs_copy = {k: v for k, v in kwargs.items() if k != 'provider'}
kwargs_copy = {k: v for k, v in kwargs.items() if k != "provider"}
return cast(
Self, native_class(model=model_string, provider=provider, **kwargs_copy)
Self,
native_class(model=model_string, provider=provider, **kwargs_copy),
)
except NotImplementedError:
raise
@@ -404,46 +406,100 @@ class LLM(BaseLLM):
instance.is_litellm = True
return instance

@classmethod
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
"""Check if a model name matches provider-specific patterns.

This allows supporting models that aren't in the hardcoded constants list,
including "latest" versions and new models that follow provider naming conventions.

Args:
model: The model name to check
provider: The provider to check against (canonical name)

Returns:
True if the model matches the provider's naming pattern, False otherwise
"""
model_lower = model.lower()

if provider == "openai":
return any(
model_lower.startswith(prefix)
for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
)

if provider == "anthropic" or provider == "claude":
return any(
model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
)

if provider == "gemini" or provider == "google":
return any(
model_lower.startswith(prefix)
for prefix in ["gemini-", "gemma-", "learnlm-"]
)

if provider == "bedrock":
return "." in model_lower

if provider == "azure":
return any(
model_lower.startswith(prefix)
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
)

return False

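A hedged illustration of the pattern fallback above; the model names here are invented for the example and are exactly the kind of not-yet-hardcoded names the method is meant to accept.

assert LLM._matches_provider_pattern("claude-sonnet-9-latest", "anthropic")        # "claude-" prefix
assert LLM._matches_provider_pattern("gemini-9.9-flash", "gemini")                 # "gemini-" prefix
assert LLM._matches_provider_pattern("amazon.hypothetical-model-v1:0", "bedrock")  # dot convention
assert not LLM._matches_provider_pattern("mistral-large", "openai")                # no OpenAI prefix
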
@classmethod
def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
"""Validate if a model name exists in the provider's constants.
"""Validate if a model name exists in the provider's constants or matches provider patterns.

This method first checks the hardcoded constants list for known models.
If not found, it falls back to pattern matching to support new models,
"latest" versions, and models that follow provider naming conventions.

Args:
model: The model name to validate
provider: The provider to check against (canonical name)

Returns:
True if the model exists in the provider's constants, False otherwise
True if the model exists in constants or matches provider patterns, False otherwise
"""
if provider == "openai":
return model in OPENAI_MODELS
if provider == "openai" and model in OPENAI_MODELS:
return True

if provider == "anthropic" or provider == "claude":
return model in ANTHROPIC_MODELS
if (
provider == "anthropic" or provider == "claude"
) and model in ANTHROPIC_MODELS:
return True

if provider == "gemini":
return model in GEMINI_MODELS
if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
return True

if provider == "bedrock":
return model in BEDROCK_MODELS
if provider == "bedrock" and model in BEDROCK_MODELS:
return True

if provider == "azure":
# azure does not provide a list of available models, determine a better way to handle this
return True

return False
# Fallback to pattern matching for models not in constants
return cls._matches_provider_pattern(model, provider)

@classmethod
def _infer_provider_from_model(cls, model: str) -> str:
"""Infer the provider from the model name.

This method first checks the hardcoded constants list for known models.
If not found, it uses pattern matching to infer the provider from model name patterns.
This allows supporting new models and "latest" versions without hardcoding.

Args:
model: The model name without provider prefix

Returns:
The inferred provider name, defaults to "openai"
"""

if model in OPENAI_MODELS:
return "openai"

@@ -756,6 +812,7 @@ class LLM(BaseLLM):
chunk=chunk_content,
from_task=from_task,
from_agent=from_agent,
call_type=LLMCallType.LLM_CALL,
),
)
# --- 4) Fallback to non-streaming if no content received
@@ -957,6 +1014,7 @@ class LLM(BaseLLM):
chunk=tool_call.function.arguments,
from_task=from_task,
from_agent=from_agent,
call_type=LLMCallType.TOOL_CALL,
),
)

@@ -1695,12 +1753,14 @@ class LLM(BaseLLM):
max_tokens=self.max_tokens,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
logit_bias=copy.deepcopy(self.logit_bias, memo)
if self.logit_bias
else None,
response_format=copy.deepcopy(self.response_format, memo)
if self.response_format
else None,
logit_bias=(
copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
),
response_format=(
copy.deepcopy(self.response_format, memo)
if self.response_format
else None
),
seed=self.seed,
logprobs=self.logprobs,
top_logprobs=self.top_logprobs,

@@ -182,6 +182,8 @@ OPENAI_MODELS: list[OpenAIModels] = [


AnthropicModels: TypeAlias = Literal[
"claude-opus-4-5-20251101",
"claude-opus-4-5",
"claude-3-7-sonnet-latest",
"claude-3-7-sonnet-20250219",
"claude-3-5-haiku-latest",
@@ -208,6 +210,8 @@ AnthropicModels: TypeAlias = Literal[
"claude-3-haiku-20240307",
]
ANTHROPIC_MODELS: list[AnthropicModels] = [
"claude-opus-4-5-20251101",
"claude-opus-4-5",
"claude-3-7-sonnet-latest",
"claude-3-7-sonnet-20250219",
"claude-3-5-haiku-latest",
@@ -235,6 +239,7 @@ ANTHROPIC_MODELS: list[AnthropicModels] = [
]

GeminiModels: TypeAlias = Literal[
"gemini-3-pro-preview",
"gemini-2.5-pro",
"gemini-2.5-pro-preview-03-25",
"gemini-2.5-pro-preview-05-06",
@@ -287,6 +292,7 @@ GeminiModels: TypeAlias = Literal[
"learnlm-2.0-flash-experimental",
]
GEMINI_MODELS: list[GeminiModels] = [
"gemini-3-pro-preview",
"gemini-2.5-pro",
"gemini-2.5-pro-preview-03-25",
"gemini-2.5-pro-preview-05-06",
@@ -450,6 +456,7 @@ BedrockModels: TypeAlias = Literal[
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
"anthropic.claude-haiku-4-5-20251001-v1:0",
"anthropic.claude-instant-v1:2:100k",
"anthropic.claude-opus-4-5-20251101-v1:0",
"anthropic.claude-opus-4-1-20250805-v1:0",
"anthropic.claude-opus-4-20250514-v1:0",
"anthropic.claude-sonnet-4-20250514-v1:0",
@@ -522,6 +529,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
"anthropic.claude-haiku-4-5-20251001-v1:0",
"anthropic.claude-instant-v1:2:100k",
"anthropic.claude-opus-4-5-20251101-v1:0",
"anthropic.claude-opus-4-1-20250805-v1:0",
"anthropic.claude-opus-4-20250514-v1:0",
"anthropic.claude-sonnet-4-20250514-v1:0",

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
from pydantic import BaseModel

from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.converter import generate_model_description
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
@@ -26,6 +27,7 @@ try:
from azure.ai.inference.models import (
ChatCompletions,
ChatCompletionsToolCall,
JsonSchemaFormat,
StreamingChatCompletionsUpdate,
)
from azure.core.credentials import (
@@ -278,13 +280,16 @@ class AzureCompletion(BaseLLM):
}

if response_model and self.is_openai_model:
params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": response_model.__name__,
"schema": response_model.model_json_schema(),
},
}
model_description = generate_model_description(response_model)
json_schema_info = model_description["json_schema"]
json_schema_name = json_schema_info["name"]

params["response_format"] = JsonSchemaFormat(
name=json_schema_name,
schema=json_schema_info["schema"],
description=f"Schema for {json_schema_name}",
strict=json_schema_info["strict"],
)

# Only include model parameter for non-Azure OpenAI endpoints
# Azure OpenAI endpoints have the deployment name in the URL
@@ -310,6 +315,14 @@ class AzureCompletion(BaseLLM):
params["tools"] = self._convert_tools_for_interference(tools)
params["tool_choice"] = "auto"

additional_params = self.additional_params
additional_drop_params = additional_params.get("additional_drop_params")
drop_params = additional_params.get("drop_params")

if drop_params and isinstance(additional_drop_params, list):
for drop_param in additional_drop_params:
params.pop(drop_param, None)

return params

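A sketch of how the new drop-params handling above might be exercised, assuming additional_params is populated from extra keyword arguments on the LLM constructor (that wiring sits outside this hunk, so treat the exact kwargs as an assumption):

from crewai import LLM

llm = LLM(
    model="azure/gpt-4o",
    drop_params=True,
    additional_drop_params=["tool_choice"],  # strip request params the endpoint rejects
)
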
def _convert_tools_for_interference(

@@ -1,5 +1,6 @@
import logging
import os
import re
from typing import Any, cast

from pydantic import BaseModel
@@ -100,9 +101,8 @@ class GeminiCompletion(BaseLLM):
self.stop_sequences = stop_sequences or []

# Model-specific settings
self.is_gemini_2 = "gemini-2" in model.lower()
self.is_gemini_1_5 = "gemini-1.5" in model.lower()
self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
self.supports_tools = bool(version_match and float(version_match.group(1)) >= 1.5)

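The regex version check above replaces the old hardcoded gemini-1.5/gemini-2 flags; a minimal standalone sketch of the same logic:

import re

def gemini_supports_tools(model: str) -> bool:
    # Treat any Gemini 1.5+ model as tool-capable; unversioned names do not match.
    match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
    return bool(match and float(match.group(1)) >= 1.5)

assert gemini_supports_tools("gemini-2.0-flash")
assert gemini_supports_tools("gemini-3-pro-preview")
assert not gemini_supports_tools("gemini-1.0-pro")
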
@property
def stop(self) -> list[str]:
@@ -559,6 +559,7 @@ class GeminiCompletion(BaseLLM):
)

context_windows = {
"gemini-3-pro-preview": 1048576, # 1M tokens
"gemini-2.0-flash": 1048576, # 1M tokens
"gemini-2.0-flash-thinking": 32768,
"gemini-2.0-flash-lite": 1048576,

@@ -17,6 +17,7 @@ from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.converter import generate_model_description
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
@@ -245,6 +246,16 @@ class OpenAICompletion(BaseLLM):
if self.is_o1_model and self.reasoning_effort:
params["reasoning_effort"] = self.reasoning_effort

if self.response_format is not None:
if isinstance(self.response_format, type) and issubclass(
self.response_format, BaseModel
):
params["response_format"] = generate_model_description(
self.response_format
)
elif isinstance(self.response_format, dict):
params["response_format"] = self.response_format

if tools:
params["tools"] = self._convert_tools_for_interference(tools)
params["tool_choice"] = "auto"
@@ -303,8 +314,11 @@ class OpenAICompletion(BaseLLM):
"""Handle non-streaming chat completion."""
try:
if response_model:
parse_params = {
k: v for k, v in params.items() if k != "response_format"
}
parsed_response = self.client.beta.chat.completions.parse(
**params,
**parse_params,
response_format=response_model,
)
math_reasoning = parsed_response.choices[0].message

@@ -66,7 +66,6 @@ class SSETransport(BaseTransport):
self._transport_context = sse_client(
self.url,
headers=self.headers if self.headers else None,
terminate_on_close=True,
)

read, write = await self._transport_context.__aenter__()

@@ -16,6 +16,7 @@ from crewai.utilities.paths import db_storage_path


if TYPE_CHECKING:
from crewai.crew import Crew
from crewai.rag.core.base_client import BaseClient
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.types import ProviderSpec
@@ -32,16 +33,16 @@ class RAGStorage(BaseRAGStorage):
self,
type: str,
allow_reset: bool = True,
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
crew: Any = None,
embedder_config: ProviderSpec | BaseEmbeddingsProvider[Any] | None = None,
crew: Crew | None = None,
path: str | None = None,
) -> None:
super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
self.agents = agents
self.storage_file_name = self._build_storage_file_name(type, agents)
crew_agents = crew.agents if crew else []
sanitized_roles = [self._sanitize_role(agent.role) for agent in crew_agents]
agents_str = "_".join(sanitized_roles)
self.agents = agents_str
self.storage_file_name = self._build_storage_file_name(type, agents_str)

self.type = type
self._client: BaseClient | None = None
@@ -96,6 +97,10 @@ class RAGStorage(BaseRAGStorage):
ChromaEmbeddingFunctionWrapper, embedding_function
)
)

if self.path:
config.settings.persist_directory = self.path

self._client = create_client(config)

def _get_client(self) -> BaseClient:

@@ -2,8 +2,10 @@

from __future__ import annotations

import asyncio
from collections.abc import Callable
from functools import wraps
import inspect
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload

from crewai.project.utils import memoize
@@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
return CacheHandlerMethod(memoize(meth))


def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
"""Call a method, awaiting it if async and running in an event loop."""
result = method(*args, **kwargs)
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures

with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result


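With _call_method in place, @CrewBase builders may be coroutines; a hedged sketch using the standard crewai.project decorators (the async task method is resolved transparently when the crew is assembled):

from crewai import Agent, Crew, Task
from crewai.project import CrewBase, agent, crew, task

@CrewBase
class ResearchCrew:
    @agent
    def researcher(self) -> Agent:
        return Agent(role="Researcher", goal="Find facts", backstory="Thorough analyst")

    @task
    async def research_task(self) -> Task:  # async builders are now supported
        return Task(description="Research X", expected_output="Notes", agent=self.researcher())

    @crew
    def crew(self) -> Crew:
        return Crew(agents=self.agents, tasks=self.tasks)
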
@overload
def crew(
meth: Callable[Concatenate[SelfT, P], Crew],
@@ -198,7 +217,7 @@ def crew(

# Instantiate tasks in order
for _, task_method in tasks:
task_instance = task_method(self)
task_instance = _call_method(task_method, self)
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
if agent_instance and agent_instance.role not in agent_roles:
@@ -207,7 +226,7 @@ def crew(

# Instantiate agents not included by tasks
for _, agent_method in agents:
agent_instance = agent_method(self)
agent_instance = _call_method(agent_method, self)
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
agent_roles.add(agent_instance.role)
@@ -215,7 +234,7 @@ def crew(
self.agents = instantiated_agents
self.tasks = instantiated_tasks

crew_instance = meth(self, *args, **kwargs)
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)

def callback_wrapper(
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance

@@ -1,7 +1,8 @@
"""Utility functions for the crewai project module."""

from collections.abc import Callable
from collections.abc import Callable, Coroutine
from functools import wraps
import inspect
from typing import Any, ParamSpec, TypeVar, cast

from pydantic import BaseModel
@@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any:
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a method by caching its results based on arguments.

Handles Pydantic BaseModel instances by converting them to JSON strings
before hashing for cache lookup.
Handles both sync and async methods. Pydantic BaseModel instances are
converted to JSON strings before hashing for cache lookup.

Args:
meth: The method to memoize.
@@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
Returns:
A memoized version of the method that caches results.
"""
if inspect.iscoroutinefunction(meth):
return cast(Callable[P, R], _memoize_async(meth))
return _memoize_sync(meth)


def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a synchronous method."""

@wraps(meth)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
"""Wrapper that converts arguments to hashable form before caching.

Args:
*args: Positional arguments to the memoized method.
**kwargs: Keyword arguments to the memoized method.

Returns:
The result of the memoized method call.
"""
hashable_args = tuple(_make_hashable(arg) for arg in args)
hashable_kwargs = tuple(
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
@@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
return result

return cast(Callable[P, R], wrapper)


def _memoize_async(
meth: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
"""Memoize an async method."""

@wraps(meth)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
hashable_args = tuple(_make_hashable(arg) for arg in args)
hashable_kwargs = tuple(
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
)
cache_key = str((hashable_args, hashable_kwargs))

cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
if cached_result is not None:
return cached_result

result = await meth(*args, **kwargs)
cache.add(tool=meth.__name__, input=cache_key, output=result)
return result

return wrapper

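A small behavioral sketch of the async branch: the first await computes and stores the result, and a repeat call with the same arguments is served from the cache. The counter exists only to observe the behavior; the cache backing is whatever `cache` refers to in this module.

import asyncio

call_count = 0

async def slow_square(x: int) -> int:
    global call_count
    call_count += 1
    return x * x

memoized = memoize(slow_square)

async def main() -> None:
    assert await memoized(4) == 16
    assert await memoized(4) == 16  # served from cache
    assert call_count == 1

asyncio.run(main())
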
@@ -2,8 +2,10 @@

from __future__ import annotations

import asyncio
from collections.abc import Callable
from functools import partial
import inspect
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -132,6 +134,22 @@ class CrewClass(Protocol):
crew: Callable[..., Crew]


def _resolve_result(result: Any) -> Any:
"""Resolve a potentially async result to its value."""
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures

with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result


class DecoratedMethod(Generic[P, R]):
"""Base wrapper for methods with decorator metadata.

@@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]):
"""
if obj is None:
return self
bound = partial(self._meth, obj)
inner = partial(self._meth, obj)

def _bound(*args: Any, **kwargs: Any) -> R:
result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg]
return result

for attr in (
"is_agent",
"is_llm",
@@ -174,8 +197,8 @@ class DecoratedMethod
"is_crew",
):
if hasattr(self, attr):
setattr(bound, attr, getattr(self, attr))
return bound
setattr(_bound, attr, getattr(self, attr))
return _bound

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Call the wrapped method.
@@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]):
The task result with name ensured.
"""
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
result = _resolve_result(result)
return self._task_method.ensure_task_name(result)


@@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]):
Returns:
The task instance with name set if not already provided.
"""
return self.ensure_task_name(self._meth(*args, **kwargs))
result = self._meth(*args, **kwargs)
result = _resolve_result(result)
return self.ensure_task_name(result)

def unwrap(self) -> Callable[P, TaskResultT]:
"""Get the original unwrapped method.

@@ -91,6 +91,7 @@ PROVIDER_PATHS = {
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
"google": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",

@@ -5,7 +5,7 @@ from typing import Any
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field

from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider

@@ -21,7 +21,7 @@ def create_aws_session() -> Any:
ValueError: If AWS session creation fails
"""
try:
import boto3 # type: ignore[import]
import boto3

return boto3.Session()
except ImportError as e:
@@ -46,7 +46,12 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
model_name: str = Field(
default="amazon.titan-embed-text-v1",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_BEDROCK_MODEL_NAME",
"BEDROCK_MODEL_NAME",
"AWS_BEDROCK_MODEL_NAME",
"model",
),
)
session: Any = Field(
default_factory=create_aws_session, description="AWS session object"

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field

from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider

@@ -15,10 +15,14 @@ class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
default=CohereEmbeddingFunction, description="Cohere embedding function class"
)
api_key: str = Field(
description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
description="Cohere API key",
validation_alias=AliasChoices("EMBEDDINGS_COHERE_API_KEY", "COHERE_API_KEY"),
)
model_name: str = Field(
default="large",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_COHERE_MODEL_NAME",
"model",
),
)

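The practical effect of the AliasChoices migration, sketched for the Cohere provider: the provider-native environment variable is now honored alongside the EMBEDDINGS_-prefixed one, and "model" works as an init alias for model_name. The values are placeholders, and the env-var lookup is assumed from the alias names shown.

import os

from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider

os.environ["COHERE_API_KEY"] = "co-example-key"  # previously only EMBEDDINGS_COHERE_API_KEY was read

provider = CohereProvider(model="embed-english-v3.0")  # "model" aliases model_name
print(provider.model_name)  # embed-english-v3.0
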
@@ -1,9 +1,11 @@
|
||||
"""Google Generative AI embeddings provider."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,16 +17,27 @@ class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFun
|
||||
default=GoogleGenerativeAiEmbeddingFunction,
|
||||
description="Google Generative AI embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="models/embedding-001",
|
||||
model_name: Literal[
|
||||
"gemini-embedding-001", "text-embedding-005", "text-multilingual-embedding-002"
|
||||
] = Field(
|
||||
default="gemini-embedding-001",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME", "model"
|
||||
),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
|
||||
description="Google API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_API_KEY", "GOOGLE_API_KEY", "GEMINI_API_KEY"
|
||||
),
|
||||
)
|
||||
task_type: str = Field(
|
||||
default="RETRIEVAL_DOCUMENT",
|
||||
description="Task type for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
"GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
"GEMINI_TASK_TYPE",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -6,10 +6,23 @@ from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class GenerativeAiProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Google Generative AI provider."""
|
||||
"""Configuration for Google Generative AI provider.
|
||||
|
||||
Attributes:
|
||||
api_key: Google API key for authentication.
|
||||
model_name: Embedding model name.
|
||||
task_type: Task type for embeddings. Default is "RETRIEVAL_DOCUMENT".
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "models/embedding-001"]
|
||||
model_name: Annotated[
|
||||
Literal[
|
||||
"gemini-embedding-001",
|
||||
"text-embedding-005",
|
||||
"text-multilingual-embedding-002",
|
||||
],
|
||||
"gemini-embedding-001",
|
||||
]
|
||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,18 +18,29 @@ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="textembedding-gecko",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
"GOOGLE_VERTEX_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
|
||||
description="Google API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
|
||||
),
|
||||
)
|
||||
project_id: str = Field(
|
||||
default="cloud-large-language-models",
|
||||
description="GCP project ID",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
|
||||
),
|
||||
)
|
||||
region: str = Field(
|
||||
default="us-central1",
|
||||
description="GCP region",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -16,5 +16,6 @@ class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
|
||||
description="HuggingFace API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import AliasChoices, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
@@ -21,7 +21,10 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
|
||||
description="WatsonX model ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_MODEL_ID", "WATSONX_MODEL_ID"
|
||||
),
|
||||
)
|
||||
params: dict[str, str | dict[str, str]] | None = Field(
|
||||
default=None, description="Additional parameters"
|
||||
@@ -30,109 +33,143 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX project ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PROJECT_ID", "WATSONX_PROJECT_ID"
|
||||
),
|
||||
)
|
||||
space_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX space ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_SPACE_ID", "WATSONX_SPACE_ID"
|
||||
),
|
||||
)
|
||||
api_client: Any | None = Field(default=None, description="WatsonX API client")
|
||||
verify: bool | str | None = Field(
|
||||
default=None,
|
||||
description="SSL verification",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERIFY", "WATSONX_VERIFY"),
|
||||
)
|
||||
persistent_connection: bool = Field(
|
||||
default=True,
|
||||
description="Use persistent connection",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION", "WATSONX_PERSISTENT_CONNECTION"
|
||||
),
|
||||
)
|
||||
batch_size: int = Field(
|
||||
default=100,
|
||||
description="Batch size for processing",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_BATCH_SIZE", "WATSONX_BATCH_SIZE"
|
||||
),
|
||||
)
|
||||
concurrency_limit: int = Field(
|
||||
default=10,
|
||||
description="Concurrency limit",
|
||||
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT", "WATSONX_CONCURRENCY_LIMIT"
|
||||
),
|
||||
)
|
||||
max_retries: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum retries",
|
||||
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_MAX_RETRIES", "WATSONX_MAX_RETRIES"
|
||||
),
|
||||
)
|
||||
delay_time: float | None = Field(
|
||||
default=None,
|
||||
description="Delay time between retries",
|
||||
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_DELAY_TIME", "WATSONX_DELAY_TIME"
|
||||
),
|
||||
)
|
||||
retry_status_codes: list[int] | None = Field(
|
||||
default=None, description="HTTP status codes to retry on"
|
||||
)
|
||||
url: str = Field(
|
||||
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
|
||||
description="WatsonX API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_URL", "WATSONX_URL"),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
|
||||
description="WatsonX API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_API_KEY", "WATSONX_API_KEY"),
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Service name",
|
||||
validation_alias="EMBEDDINGS_WATSONX_NAME",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_NAME", "WATSONX_NAME"),
|
||||
)
|
||||
iam_serviceid_crn: str | None = Field(
|
||||
default=None,
|
||||
description="IAM service ID CRN",
|
||||
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN", "WATSONX_IAM_SERVICEID_CRN"
|
||||
),
|
||||
)
|
||||
trusted_profile_id: str | None = Field(
|
||||
default=None,
|
||||
description="Trusted profile ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID", "WATSONX_TRUSTED_PROFILE_ID"
|
||||
),
|
||||
)
|
||||
token: str | None = Field(
|
||||
default=None,
|
||||
description="Bearer token",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_TOKEN", "WATSONX_TOKEN"),
|
||||
)
|
||||
projects_token: str | None = Field(
|
||||
default=None,
|
||||
description="Projects token",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PROJECTS_TOKEN", "WATSONX_PROJECTS_TOKEN"
|
||||
),
|
||||
)
|
||||
username: str | None = Field(
|
||||
default=None,
|
||||
description="Username",
|
||||
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_USERNAME", "WATSONX_USERNAME"
|
||||
),
|
||||
)
|
||||
password: str | None = Field(
|
||||
default=None,
|
||||
description="Password",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PASSWORD", "WATSONX_PASSWORD"
|
||||
),
|
||||
)
|
||||
instance_id: str | None = Field(
|
||||
default=None,
|
||||
description="Service instance ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_INSTANCE_ID", "WATSONX_INSTANCE_ID"
|
||||
),
|
||||
)
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERSION",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERSION", "WATSONX_VERSION"),
|
||||
)
|
||||
bedrock_url: str | None = Field(
|
||||
default=None,
|
||||
description="Bedrock URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_BEDROCK_URL", "WATSONX_BEDROCK_URL"
|
||||
),
|
||||
)
|
||||
platform_url: str | None = Field(
|
||||
default=None,
|
||||
description="Platform URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PLATFORM_URL", "WATSONX_PLATFORM_URL"
|
||||
),
|
||||
)
|
||||
proxies: dict[str, Any] | None = Field(
|
||||
default=None, description="Proxy configuration"
|
||||
)
|
||||
proxies: dict | None = Field(default=None, description="Proxy configuration")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_space_or_project(self) -> Self:
|
||||
|
||||
@@ -3,7 +3,7 @@
 from chromadb.utils.embedding_functions.instructor_embedding_function import (
     InstructorEmbeddingFunction,
 )
-from pydantic import Field
+from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
 
@@ -18,15 +18,23 @@ class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
     model_name: str = Field(
         default="hkunlp/instructor-base",
         description="Model name to use",
-        validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
+            "INSTRUCTOR_MODEL_NAME",
+            "model",
+        ),
     )
     device: str = Field(
         default="cpu",
         description="Device to run model on (cpu or cuda)",
-        validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_INSTRUCTOR_DEVICE", "INSTRUCTOR_DEVICE"
+        ),
     )
     instruction: str | None = Field(
         default=None,
         description="Instruction for embeddings",
-        validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_INSTRUCTOR_INSTRUCTION", "INSTRUCTOR_INSTRUCTION"
+        ),
     )
@@ -3,7 +3,7 @@
 from chromadb.utils.embedding_functions.jina_embedding_function import (
     JinaEmbeddingFunction,
 )
-from pydantic import Field
+from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
 
@@ -15,10 +15,15 @@ class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
         default=JinaEmbeddingFunction, description="Jina embedding function class"
     )
     api_key: str = Field(
-        description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
+        description="Jina API key",
+        validation_alias=AliasChoices("EMBEDDINGS_JINA_API_KEY", "JINA_API_KEY"),
     )
     model_name: str = Field(
         default="jina-embeddings-v2-base-en",
         description="Model name to use for embeddings",
-        validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_JINA_MODEL_NAME",
+            "JINA_MODEL_NAME",
+            "model",
+        ),
     )
@@ -5,7 +5,7 @@ from typing import Any
 from chromadb.utils.embedding_functions.openai_embedding_function import (
     OpenAIEmbeddingFunction,
 )
-from pydantic import Field
+from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
 
@@ -18,27 +18,39 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
         description="Azure OpenAI embedding function class",
     )
     api_key: str = Field(
-        description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
+        description="Azure API key",
+        validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
     )
     api_base: str | None = Field(
         default=None,
         description="Azure endpoint URL",
-        validation_alias="EMBEDDINGS_OPENAI_API_BASE",
+        validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
     )
     api_type: str = Field(
         default="azure",
         description="API type for Azure",
-        validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE", "AZURE_OPENAI_API_TYPE"
+        ),
     )
     api_version: str | None = Field(
-        default=None,
+        default="2024-02-01",
         description="Azure API version",
-        validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENAI_API_VERSION",
+            "OPENAI_API_VERSION",
+            "AZURE_OPENAI_API_VERSION",
+        ),
     )
     model_name: str = Field(
         default="text-embedding-ada-002",
         description="Model name to use for embeddings",
-        validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENAI_MODEL_NAME",
+            "OPENAI_MODEL_NAME",
+            "AZURE_OPENAI_MODEL_NAME",
+            "model",
+        ),
     )
     default_headers: dict[str, Any] | None = Field(
         default=None, description="Default headers for API requests"
@@ -46,15 +58,26 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
     dimensions: int | None = Field(
         default=None,
         description="Embedding dimensions",
-        validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENAI_DIMENSIONS",
+            "OPENAI_DIMENSIONS",
+            "AZURE_OPENAI_DIMENSIONS",
+        ),
     )
-    deployment_id: str | None = Field(
-        default=None,
+    deployment_id: str = Field(
         description="Azure deployment ID",
-        validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
+            "AZURE_OPENAI_DEPLOYMENT",
+            "AZURE_DEPLOYMENT_ID",
+        ),
     )
     organization_id: str | None = Field(
         default=None,
         description="Organization ID",
-        validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENAI_ORGANIZATION_ID",
+            "OPENAI_ORGANIZATION_ID",
+            "AZURE_OPENAI_ORGANIZATION_ID",
+        ),
     )
@@ -15,7 +15,7 @@ class AzureProviderConfig(TypedDict, total=False):
     model_name: Annotated[str, "text-embedding-ada-002"]
     default_headers: dict[str, Any]
     dimensions: int
-    deployment_id: str
+    deployment_id: Required[str]
     organization_id: str
@@ -3,7 +3,7 @@
 from chromadb.utils.embedding_functions.ollama_embedding_function import (
     OllamaEmbeddingFunction,
 )
-from pydantic import Field
+from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
 
@@ -17,9 +17,14 @@ class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
     url: str = Field(
         default="http://localhost:11434/api/embeddings",
         description="Ollama API endpoint URL",
-        validation_alias="EMBEDDINGS_OLLAMA_URL",
+        validation_alias=AliasChoices("EMBEDDINGS_OLLAMA_URL", "OLLAMA_URL"),
     )
     model_name: str = Field(
         description="Model name to use for embeddings",
-        validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OLLAMA_MODEL_NAME",
+            "OLLAMA_MODEL_NAME",
+            "OLLAMA_MODEL",
+            "model",
+        ),
     )
@@ -1,7 +1,7 @@
 """ONNX embeddings provider."""
 
 from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
-from pydantic import Field
+from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
 
@@ -15,5 +15,7 @@ class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
     preferred_providers: list[str] | None = Field(
         default=None,
         description="Preferred ONNX execution providers",
-        validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_ONNX_PREFERRED_PROVIDERS", "ONNX_PREFERRED_PROVIDERS"
+        ),
     )
@@ -5,7 +5,7 @@ from typing import Any
 from chromadb.utils.embedding_functions.openai_embedding_function import (
     OpenAIEmbeddingFunction,
 )
-from pydantic import Field
+from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
 
@@ -20,27 +20,33 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
     api_key: str | None = Field(
         default=None,
         description="OpenAI API key",
-        validation_alias="EMBEDDINGS_OPENAI_API_KEY",
+        validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
     )
     model_name: str = Field(
         default="text-embedding-ada-002",
         description="Model name to use for embeddings",
-        validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENAI_MODEL_NAME",
+            "OPENAI_MODEL_NAME",
+            "model",
+        ),
     )
     api_base: str | None = Field(
         default=None,
         description="Base URL for API requests",
-        validation_alias="EMBEDDINGS_OPENAI_API_BASE",
+        validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
     )
     api_type: str | None = Field(
         default=None,
         description="API type (e.g., 'azure')",
-        validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
+        validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE"),
     )
     api_version: str | None = Field(
         default=None,
         description="API version",
-        validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENAI_API_VERSION", "OPENAI_API_VERSION"
+        ),
     )
     default_headers: dict[str, Any] | None = Field(
         default=None, description="Default headers for API requests"
@@ -48,15 +54,21 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
     dimensions: int | None = Field(
         default=None,
         description="Embedding dimensions",
-        validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENAI_DIMENSIONS", "OPENAI_DIMENSIONS"
+        ),
     )
     deployment_id: str | None = Field(
         default=None,
         description="Azure deployment ID",
-        validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENAI_DEPLOYMENT_ID", "OPENAI_DEPLOYMENT_ID"
+        ),
     )
     organization_id: str | None = Field(
         default=None,
         description="OpenAI organization ID",
-        validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENAI_ORGANIZATION_ID", "OPENAI_ORGANIZATION_ID"
+        ),
     )
@@ -3,7 +3,7 @@
 from chromadb.utils.embedding_functions.open_clip_embedding_function import (
     OpenCLIPEmbeddingFunction,
 )
-from pydantic import Field
+from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
 
@@ -18,15 +18,21 @@ class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
     model_name: str = Field(
         default="ViT-B-32",
         description="Model name to use",
-        validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENCLIP_MODEL_NAME",
+            "OPENCLIP_MODEL_NAME",
+            "model",
+        ),
     )
     checkpoint: str = Field(
         default="laion2b_s34b_b79k",
         description="Model checkpoint",
-        validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_OPENCLIP_CHECKPOINT", "OPENCLIP_CHECKPOINT"
+        ),
     )
     device: str | None = Field(
         default="cpu",
         description="Device to run model on",
-        validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
+        validation_alias=AliasChoices("EMBEDDINGS_OPENCLIP_DEVICE", "OPENCLIP_DEVICE"),
     )
@@ -3,7 +3,7 @@
 from chromadb.utils.embedding_functions.roboflow_embedding_function import (
     RoboflowEmbeddingFunction,
 )
-from pydantic import Field
+from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
 
@@ -18,10 +18,14 @@ class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
     api_key: str = Field(
         default="",
         description="Roboflow API key",
-        validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_ROBOFLOW_API_KEY", "ROBOFLOW_API_KEY"
+        ),
     )
     api_url: str = Field(
         default="https://infer.roboflow.com",
         description="Roboflow API URL",
-        validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_ROBOFLOW_API_URL", "ROBOFLOW_API_URL"
+        ),
     )
@@ -3,7 +3,7 @@
 from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
     SentenceTransformerEmbeddingFunction,
 )
-from pydantic import Field
+from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
 
@@ -20,15 +20,24 @@ class SentenceTransformerProvider(
     model_name: str = Field(
         default="all-MiniLM-L6-v2",
         description="Model name to use",
-        validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
+            "SENTENCE_TRANSFORMER_MODEL_NAME",
+            "model",
+        ),
     )
     device: str = Field(
         default="cpu",
         description="Device to run model on (cpu or cuda)",
-        validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE", "SENTENCE_TRANSFORMER_DEVICE"
+        ),
     )
     normalize_embeddings: bool = Field(
         default=False,
         description="Whether to normalize embeddings",
-        validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
+            "SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
+        ),
     )
@@ -3,7 +3,7 @@
 from chromadb.utils.embedding_functions.text2vec_embedding_function import (
     Text2VecEmbeddingFunction,
 )
-from pydantic import Field
+from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
 
@@ -18,5 +18,9 @@ class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
     model_name: str = Field(
         default="shibing624/text2vec-base-chinese",
         description="Model name to use",
-        validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_TEXT2VEC_MODEL_NAME",
+            "TEXT2VEC_MODEL_NAME",
+            "model",
+        ),
     )
@@ -1,6 +1,6 @@
 """Voyage AI embeddings provider."""
 
-from pydantic import Field
+from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
 from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
@@ -18,38 +18,53 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
     model: str = Field(
         default="voyage-2",
         description="Model to use for embeddings",
-        validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
+        validation_alias=AliasChoices("EMBEDDINGS_VOYAGEAI_MODEL", "VOYAGEAI_MODEL"),
     )
     api_key: str = Field(
-        description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
+        description="Voyage AI API key",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_VOYAGEAI_API_KEY", "VOYAGEAI_API_KEY"
+        ),
     )
     input_type: str | None = Field(
         default=None,
         description="Input type for embeddings",
-        validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_VOYAGEAI_INPUT_TYPE", "VOYAGEAI_INPUT_TYPE"
+        ),
     )
     truncation: bool = Field(
         default=True,
         description="Whether to truncate inputs",
-        validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_VOYAGEAI_TRUNCATION", "VOYAGEAI_TRUNCATION"
+        ),
     )
     output_dtype: str | None = Field(
         default=None,
         description="Output data type",
-        validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE", "VOYAGEAI_OUTPUT_DTYPE"
+        ),
     )
     output_dimension: int | None = Field(
         default=None,
         description="Output dimension",
-        validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION", "VOYAGEAI_OUTPUT_DIMENSION"
+        ),
     )
     max_retries: int = Field(
         default=0,
         description="Maximum retries for API calls",
-        validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_VOYAGEAI_MAX_RETRIES", "VOYAGEAI_MAX_RETRIES"
+        ),
     )
     timeout: float | None = Field(
         default=None,
         description="Timeout for API calls",
-        validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_VOYAGEAI_TIMEOUT", "VOYAGEAI_TIMEOUT"
+        ),
     )
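Every provider hunk above applies the same mechanical change: a field that previously accepted a single `EMBEDDINGS_*` validation alias now accepts several alternative spellings via `AliasChoices`. A minimal sketch of how that behaves, assuming pydantic v2; the `Settings` class and the values below are illustrative, not part of the diff:

```python
# Illustrative only: AliasChoices lets one field validate from several names.
from pydantic import AliasChoices, BaseModel, Field


class Settings(BaseModel):
    api_key: str = Field(
        validation_alias=AliasChoices("EMBEDDINGS_JINA_API_KEY", "JINA_API_KEY")
    )


# Either spelling populates the field; aliases are tried in declaration order.
assert Settings.model_validate({"JINA_API_KEY": "sk-test"}).api_key == "sk-test"
assert Settings.model_validate({"EMBEDDINGS_JINA_API_KEY": "sk-2"}).api_key == "sk-2"
```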
@@ -29,7 +29,7 @@ from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
 from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
 
 
-ProviderSpec = (
+ProviderSpec: TypeAlias = (
     AzureProviderSpec
     | BedrockProviderSpec
     | CohereProviderSpec
@@ -1,16 +1,23 @@
 """Qdrant configuration model."""
 
 from __future__ import annotations
 
 from dataclasses import field
-from typing import Literal, cast
+from typing import TYPE_CHECKING, Any, Literal, cast
 
 from pydantic.dataclasses import dataclass as pyd_dataclass
-from qdrant_client.models import VectorParams
 
 from crewai.rag.config.base import BaseRagConfig
 from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
 from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
 
 
+if TYPE_CHECKING:
+    from qdrant_client.models import VectorParams
+else:
+    VectorParams = Any
+
+
 def _default_options() -> QdrantClientParams:
     """Create default Qdrant client options.
 
@@ -26,7 +33,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
     Returns:
         Default embedding function using fastembed with all-MiniLM-L6-v2.
     """
-    from fastembed import TextEmbedding  # type: ignore[import-not-found]
+    from fastembed import TextEmbedding
 
     model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
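The Qdrant hunk defers the `qdrant_client` import to type-checking time while keeping the `VectorParams` name resolvable at runtime. A sketch of that pattern in isolation; the `configure` function here is a made-up example, not from the diff:

```python
# Illustrative only: import a heavy dependency for type checkers, but fall
# back to Any at runtime so annotation resolution still finds the name.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Seen only by type checkers; never imported when the program runs.
    from qdrant_client.models import VectorParams
else:
    # Runtime stand-in: keeps modules importable without qdrant_client.
    VectorParams = Any


def configure(params: VectorParams) -> None:
    """Hypothetical consumer of the guarded type."""
```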
lib/crewai/src/crewai/types/streaming.py (new file, 361 lines)
@@ -0,0 +1,361 @@
"""Streaming output types for crew and flow execution."""

from __future__ import annotations

from collections.abc import AsyncIterator, Iterator
from enum import Enum
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from pydantic import BaseModel, Field


if TYPE_CHECKING:
    from crewai.crews.crew_output import CrewOutput


T = TypeVar("T")


class StreamChunkType(Enum):
    """Type of streaming chunk."""

    TEXT = "text"
    TOOL_CALL = "tool_call"


class ToolCallChunk(BaseModel):
    """Tool call information in a streaming chunk.

    Attributes:
        tool_id: Unique identifier for the tool call
        tool_name: Name of the tool being called
        arguments: JSON string of tool arguments
        index: Index of the tool call in the response
    """

    tool_id: str | None = None
    tool_name: str | None = None
    arguments: str = ""
    index: int = 0


class StreamChunk(BaseModel):
    """Base streaming chunk with full context.

    Attributes:
        content: The streaming content (text or partial content)
        chunk_type: Type of the chunk (text, tool_call, etc.)
        task_index: Index of the current task (0-based)
        task_name: Name or description of the current task
        task_id: Unique identifier of the task
        agent_role: Role of the agent executing the task
        agent_id: Unique identifier of the agent
        tool_call: Tool call information if chunk_type is TOOL_CALL
    """

    content: str = Field(description="The streaming content")
    chunk_type: StreamChunkType = Field(
        default=StreamChunkType.TEXT, description="Type of the chunk"
    )
    task_index: int = Field(default=0, description="Index of the current task")
    task_name: str = Field(default="", description="Name of the current task")
    task_id: str = Field(default="", description="Unique identifier of the task")
    agent_role: str = Field(default="", description="Role of the agent")
    agent_id: str = Field(default="", description="Unique identifier of the agent")
    tool_call: ToolCallChunk | None = Field(
        default=None, description="Tool call information"
    )

    def __str__(self) -> str:
        """Return the chunk content as a string."""
        return self.content


class StreamingOutputBase(Generic[T]):
    """Base class for streaming output with result access.

    Provides iteration over stream chunks and access to final result
    via the .result property after streaming completes.
    """

    def __init__(self) -> None:
        """Initialize streaming output base."""
        self._result: T | None = None
        self._completed: bool = False
        self._chunks: list[StreamChunk] = []
        self._error: Exception | None = None

    @property
    def result(self) -> T:
        """Get the final result after streaming completes.

        Returns:
            The final output (CrewOutput for crews, Any for flows).

        Raises:
            RuntimeError: If streaming has not completed yet.
            Exception: If streaming failed with an error.
        """
        if not self._completed:
            raise RuntimeError(
                "Streaming has not completed yet. "
                "Iterate over all chunks before accessing result."
            )
        if self._error is not None:
            raise self._error
        if self._result is None:
            raise RuntimeError("No result available")
        return self._result

    @property
    def is_completed(self) -> bool:
        """Check if streaming has completed."""
        return self._completed

    @property
    def chunks(self) -> list[StreamChunk]:
        """Get all collected chunks so far."""
        return self._chunks.copy()

    def get_full_text(self) -> str:
        """Get all streamed text content concatenated.

        Returns:
            All text chunks concatenated together.
        """
        return "".join(
            chunk.content
            for chunk in self._chunks
            if chunk.chunk_type == StreamChunkType.TEXT
        )


class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
    """Streaming output wrapper for crew execution.

    Provides both sync and async iteration over stream chunks,
    with access to the final CrewOutput via the .result property.

    For kickoff_for_each_async with streaming, use .results to get list of outputs.

    Example:
        ```python
        # Single crew
        streaming = crew.kickoff(inputs={"topic": "AI"})
        for chunk in streaming:
            print(chunk.content, end="", flush=True)
        result = streaming.result

        # Multiple crews (kickoff_for_each_async)
        streaming = await crew.kickoff_for_each_async(
            [{"topic": "AI"}, {"topic": "ML"}]
        )
        async for chunk in streaming:
            print(chunk.content, end="", flush=True)
        results = streaming.results  # List of CrewOutput
        ```
    """

    def __init__(
        self,
        sync_iterator: Iterator[StreamChunk] | None = None,
        async_iterator: AsyncIterator[StreamChunk] | None = None,
    ) -> None:
        """Initialize crew streaming output.

        Args:
            sync_iterator: Synchronous iterator for chunks.
            async_iterator: Asynchronous iterator for chunks.
        """
        super().__init__()
        self._sync_iterator = sync_iterator
        self._async_iterator = async_iterator
        self._results: list[CrewOutput] | None = None

    @property
    def results(self) -> list[CrewOutput]:
        """Get all results for kickoff_for_each_async.

        Returns:
            List of CrewOutput from all crews.

        Raises:
            RuntimeError: If streaming has not completed or results not available.
        """
        if not self._completed:
            raise RuntimeError(
                "Streaming has not completed yet. "
                "Iterate over all chunks before accessing results."
            )
        if self._error is not None:
            raise self._error
        if self._results is not None:
            return self._results
        if self._result is not None:
            return [self._result]
        raise RuntimeError("No results available")

    def _set_results(self, results: list[CrewOutput]) -> None:
        """Set multiple results for kickoff_for_each_async.

        Args:
            results: List of CrewOutput from all crews.
        """
        self._results = results
        self._completed = True

    def __iter__(self) -> Iterator[StreamChunk]:
        """Iterate over stream chunks synchronously.

        Yields:
            StreamChunk objects as they arrive.

        Raises:
            RuntimeError: If sync iterator not available.
        """
        if self._sync_iterator is None:
            raise RuntimeError("Sync iterator not available")
        try:
            for chunk in self._sync_iterator:
                self._chunks.append(chunk)
                yield chunk
        except Exception as e:
            self._error = e
            raise
        finally:
            self._completed = True

    def __aiter__(self) -> AsyncIterator[StreamChunk]:
        """Return async iterator for stream chunks.

        Returns:
            Async iterator for StreamChunk objects.
        """
        return self._async_iterate()

    async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
        """Iterate over stream chunks asynchronously.

        Yields:
            StreamChunk objects as they arrive.

        Raises:
            RuntimeError: If async iterator not available.
        """
        if self._async_iterator is None:
            raise RuntimeError("Async iterator not available")
        try:
            async for chunk in self._async_iterator:
                self._chunks.append(chunk)
                yield chunk
        except Exception as e:
            self._error = e
            raise
        finally:
            self._completed = True

    def _set_result(self, result: CrewOutput) -> None:
        """Set the final result after streaming completes.

        Args:
            result: The final CrewOutput.
        """
        self._result = result
        self._completed = True


class FlowStreamingOutput(StreamingOutputBase[Any]):
    """Streaming output wrapper for flow execution.

    Provides both sync and async iteration over stream chunks,
    with access to the final flow output via the .result property.

    Example:
        ```python
        # Sync usage
        streaming = flow.kickoff_streaming()
        for chunk in streaming:
            print(chunk.content, end="", flush=True)
        result = streaming.result

        # Async usage
        streaming = await flow.kickoff_streaming_async()
        async for chunk in streaming:
            print(chunk.content, end="", flush=True)
        result = streaming.result
        ```
    """

    def __init__(
        self,
        sync_iterator: Iterator[StreamChunk] | None = None,
        async_iterator: AsyncIterator[StreamChunk] | None = None,
    ) -> None:
        """Initialize flow streaming output.

        Args:
            sync_iterator: Synchronous iterator for chunks.
            async_iterator: Asynchronous iterator for chunks.
        """
        super().__init__()
        self._sync_iterator = sync_iterator
        self._async_iterator = async_iterator

    def __iter__(self) -> Iterator[StreamChunk]:
        """Iterate over stream chunks synchronously.

        Yields:
            StreamChunk objects as they arrive.

        Raises:
            RuntimeError: If sync iterator not available.
        """
        if self._sync_iterator is None:
            raise RuntimeError("Sync iterator not available")
        try:
            for chunk in self._sync_iterator:
                self._chunks.append(chunk)
                yield chunk
        except Exception as e:
            self._error = e
            raise
        finally:
            self._completed = True

    def __aiter__(self) -> AsyncIterator[StreamChunk]:
        """Return async iterator for stream chunks.

        Returns:
            Async iterator for StreamChunk objects.
        """
        return self._async_iterate()

    async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
        """Iterate over stream chunks asynchronously.

        Yields:
            StreamChunk objects as they arrive.

        Raises:
            RuntimeError: If async iterator not available.
        """
        if self._async_iterator is None:
            raise RuntimeError("Async iterator not available")
        try:
            async for chunk in self._async_iterator:
                self._chunks.append(chunk)
                yield chunk
        except Exception as e:
            self._error = e
            raise
        finally:
            self._completed = True

    def _set_result(self, result: Any) -> None:
        """Set the final result after streaming completes.

        Args:
            result: The final flow output.
        """
        self._result = result
        self._completed = True
lib/crewai/src/crewai/utilities/streaming.py (new file, 296 lines)
@@ -0,0 +1,296 @@
"""Streaming utilities for crew and flow execution."""

import asyncio
from collections.abc import AsyncIterator, Callable, Iterator
import queue
import threading
from typing import Any, NamedTuple

from typing_extensions import TypedDict

from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import LLMStreamChunkEvent
from crewai.types.streaming import (
    CrewStreamingOutput,
    FlowStreamingOutput,
    StreamChunk,
    StreamChunkType,
    ToolCallChunk,
)


class TaskInfo(TypedDict):
    """Task context information for streaming."""

    index: int
    name: str
    id: str
    agent_role: str
    agent_id: str


class StreamingState(NamedTuple):
    """Immutable state for streaming execution."""

    current_task_info: TaskInfo
    result_holder: list[Any]
    sync_queue: queue.Queue[StreamChunk | None | Exception]
    async_queue: asyncio.Queue[StreamChunk | None | Exception] | None
    loop: asyncio.AbstractEventLoop | None
    handler: Callable[[Any, BaseEvent], None]


def _extract_tool_call_info(
    event: LLMStreamChunkEvent,
) -> tuple[StreamChunkType, ToolCallChunk | None]:
    """Extract tool call information from an LLM stream chunk event.

    Args:
        event: The LLM stream chunk event to process.

    Returns:
        A tuple of (chunk_type, tool_call_chunk) where tool_call_chunk is None
        if the event is not a tool call.
    """
    if event.tool_call:
        return (
            StreamChunkType.TOOL_CALL,
            ToolCallChunk(
                tool_id=event.tool_call.id,
                tool_name=event.tool_call.function.name,
                arguments=event.tool_call.function.arguments,
                index=event.tool_call.index,
            ),
        )
    return StreamChunkType.TEXT, None


def _create_stream_chunk(
    event: LLMStreamChunkEvent,
    current_task_info: TaskInfo,
) -> StreamChunk:
    """Create a StreamChunk from an LLM stream chunk event.

    Args:
        event: The LLM stream chunk event to process.
        current_task_info: Task context info.

    Returns:
        A StreamChunk populated with event and task info.
    """
    chunk_type, tool_call_chunk = _extract_tool_call_info(event)

    return StreamChunk(
        content=event.chunk,
        chunk_type=chunk_type,
        task_index=current_task_info["index"],
        task_name=current_task_info["name"],
        task_id=current_task_info["id"],
        agent_role=event.agent_role or current_task_info["agent_role"],
        agent_id=event.agent_id or current_task_info["agent_id"],
        tool_call=tool_call_chunk,
    )


def _create_stream_handler(
    current_task_info: TaskInfo,
    sync_queue: queue.Queue[StreamChunk | None | Exception],
    async_queue: asyncio.Queue[StreamChunk | None | Exception] | None = None,
    loop: asyncio.AbstractEventLoop | None = None,
) -> Callable[[Any, BaseEvent], None]:
    """Create a stream handler function.

    Args:
        current_task_info: Task context info.
        sync_queue: Synchronous queue for chunks.
        async_queue: Optional async queue for chunks.
        loop: Optional event loop for async operations.

    Returns:
        Handler function that can be registered with the event bus.
    """

    def stream_handler(_: Any, event: BaseEvent) -> None:
        """Handle LLM stream chunk events and enqueue them.

        Args:
            _: Event source (unused).
            event: The event to process.
        """
        if not isinstance(event, LLMStreamChunkEvent):
            return

        chunk = _create_stream_chunk(event, current_task_info)

        if async_queue is not None and loop is not None:
            loop.call_soon_threadsafe(async_queue.put_nowait, chunk)
        else:
            sync_queue.put(chunk)

    return stream_handler


def _unregister_handler(handler: Callable[[Any, BaseEvent], None]) -> None:
    """Unregister a stream handler from the event bus.

    Args:
        handler: The handler function to unregister.
    """
    with crewai_event_bus._rwlock.w_locked():
        handlers: frozenset[Callable[[Any, BaseEvent], None]] = (
            crewai_event_bus._sync_handlers.get(LLMStreamChunkEvent, frozenset())
        )
        crewai_event_bus._sync_handlers[LLMStreamChunkEvent] = handlers - {handler}


def _finalize_streaming(
    state: StreamingState,
    streaming_output: CrewStreamingOutput | FlowStreamingOutput,
) -> None:
    """Finalize streaming by unregistering handler and setting result.

    Args:
        state: The streaming state to finalize.
        streaming_output: The streaming output to set the result on.
    """
    _unregister_handler(state.handler)
    if state.result_holder:
        streaming_output._set_result(state.result_holder[0])


def create_streaming_state(
    current_task_info: TaskInfo,
    result_holder: list[Any],
    use_async: bool = False,
) -> StreamingState:
    """Create and register streaming state.

    Args:
        current_task_info: Task context info.
        result_holder: List to hold the final result.
        use_async: Whether to use async queue.

    Returns:
        Initialized StreamingState with registered handler.
    """
    sync_queue: queue.Queue[StreamChunk | None | Exception] = queue.Queue()
    async_queue: asyncio.Queue[StreamChunk | None | Exception] | None = None
    loop: asyncio.AbstractEventLoop | None = None

    if use_async:
        async_queue = asyncio.Queue()
        loop = asyncio.get_event_loop()

    handler = _create_stream_handler(current_task_info, sync_queue, async_queue, loop)
    crewai_event_bus.register_handler(LLMStreamChunkEvent, handler)

    return StreamingState(
        current_task_info=current_task_info,
        result_holder=result_holder,
        sync_queue=sync_queue,
        async_queue=async_queue,
        loop=loop,
        handler=handler,
    )


def signal_end(state: StreamingState, is_async: bool = False) -> None:
    """Signal end of stream.

    Args:
        state: The streaming state.
        is_async: Whether this is an async stream.
    """
    if is_async and state.async_queue is not None and state.loop is not None:
        state.loop.call_soon_threadsafe(state.async_queue.put_nowait, None)
    else:
        state.sync_queue.put(None)


def signal_error(
    state: StreamingState, error: Exception, is_async: bool = False
) -> None:
    """Signal an error in the stream.

    Args:
        state: The streaming state.
        error: The exception to signal.
        is_async: Whether this is an async stream.
    """
    if is_async and state.async_queue is not None and state.loop is not None:
        state.loop.call_soon_threadsafe(state.async_queue.put_nowait, error)
    else:
        state.sync_queue.put(error)


def create_chunk_generator(
    state: StreamingState,
    run_func: Callable[[], None],
    output_holder: list[CrewStreamingOutput | FlowStreamingOutput],
) -> Iterator[StreamChunk]:
    """Create a chunk generator that uses a holder to access streaming output.

    Args:
        state: The streaming state.
        run_func: Function to run in a separate thread.
        output_holder: Single-element list that will contain the streaming output.

    Yields:
        StreamChunk objects as they arrive.
    """
    thread = threading.Thread(target=run_func, daemon=True)
    thread.start()

    try:
        while True:
            item = state.sync_queue.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise item
            yield item
    finally:
        thread.join()
        if output_holder:
            _finalize_streaming(state, output_holder[0])
        else:
            _unregister_handler(state.handler)


async def create_async_chunk_generator(
    state: StreamingState,
    run_coro: Callable[[], Any],
    output_holder: list[CrewStreamingOutput | FlowStreamingOutput],
) -> AsyncIterator[StreamChunk]:
    """Create an async chunk generator that uses a holder to access streaming output.

    Args:
        state: The streaming state.
        run_coro: Coroutine function to run as a task.
        output_holder: Single-element list that will contain the streaming output.

    Yields:
        StreamChunk objects as they arrive.
    """
    if state.async_queue is None:
        raise RuntimeError(
            "Async queue not initialized. Use create_streaming_state(use_async=True)."
        )

    task = asyncio.create_task(run_coro())

    try:
        while True:
            item = await state.async_queue.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise item
            yield item
    finally:
        await task
        if output_holder:
            _finalize_streaming(state, output_holder[0])
        else:
            _unregister_handler(state.handler)
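Taken together, the two new modules split responsibilities: `types/streaming.py` defines the user-facing output wrappers, and `utilities/streaming.py` wires the event bus into them. The following is a hypothetical end-to-end wiring under the internal APIs shown above; the `run` function and the stand-in result are illustrative only, and real callers would go through `crew.kickoff()` rather than this plumbing:

```python
# Illustrative wiring of the streaming helpers (not part of the diff).
from crewai.types.streaming import CrewStreamingOutput
from crewai.utilities.streaming import (
    TaskInfo,
    create_chunk_generator,
    create_streaming_state,
    signal_end,
    signal_error,
)

task_info: TaskInfo = {
    "index": 0, "name": "demo", "id": "t-1", "agent_role": "tester", "agent_id": "a-1",
}
result_holder: list = []
state = create_streaming_state(task_info, result_holder)  # registers the handler


def run() -> None:
    """Worker body run on the generator's thread."""
    try:
        result_holder.append("final output")  # stand-in for real crew execution
    except Exception as e:
        signal_error(state, e)
    finally:
        signal_end(state)  # unblocks the consuming iterator


output_holder: list = []
streaming = CrewStreamingOutput(
    sync_iterator=create_chunk_generator(state, run, output_holder)
)
output_holder.append(streaming)  # lets the generator finalize the result

for chunk in streaming:
    print(chunk.content, end="")
print(streaming.result)  # -> "final output" once iteration completes
```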
@@ -307,27 +307,22 @@ def test_cache_hitting():
             event_handled = True
             condition.notify()
 
-    with (
-        patch.object(CacheHandler, "read") as read,
-    ):
-        read.return_value = "0"
-        task = Task(
-            description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
-            agent=agent,
-            expected_output="The number that is the result of the multiplication tool.",
-        )
-        output = agent.execute_task(task)
-        assert output == "0"
-        read.assert_called_with(
-            tool="multiplier", input='{"first_number": 2, "second_number": 6}'
-        )
-        with condition:
-            if not event_handled:
-                condition.wait(timeout=5)
-            assert event_handled, "Timeout waiting for tool usage event"
-        assert len(received_events) == 1
-        assert isinstance(received_events[0], ToolUsageFinishedEvent)
-        assert received_events[0].from_cache
+    task = Task(
+        description="What is 2 times 6? Return only the result of the multiplication.",
+        agent=agent,
+        expected_output="The result of the multiplication.",
+    )
+    output = agent.execute_task(task)
+    assert output == "12"
+
+    with condition:
+        if not event_handled:
+            condition.wait(timeout=5)
+        assert event_handled, "Timeout waiting for tool usage event"
+    assert len(received_events) == 1
+    assert isinstance(received_events[0], ToolUsageFinishedEvent)
+    assert received_events[0].from_cache
+    assert received_events[0].output == "12"
 
 
 @pytest.mark.vcr(filter_headers=["authorization"])
File diff suppressed because it is too large (7 files)
@@ -1,23 +1,22 @@
 interactions:
 - request:
-    body: '{"messages": [{"role": "system", "content": "You are test role. test backstory\nYour
+    body: '{"messages":[{"role":"system","content":"You are test role. test backstory\nYour
       personal goal is: test goal\nYou ONLY have access to the following tools, and
       should NEVER make up tools that are not listed here:\n\nTool Name: dummy_tool\nTool
       Arguments: {''query'': {''description'': None, ''type'': ''str''}}\nTool Description:
-      Useful for when you need to get a dummy result for a query.\n\nUse the following
-      format:\n\nThought: you should always think about what to do\nAction: the action
-      to take, only one name of [dummy_tool], just the name, exactly as it''s written.\nAction
-      Input: the input to the action, just a simple python dictionary, enclosed in
-      curly braces, using \" to wrap keys and values.\nObservation: the result of
-      the action\n\nOnce all necessary information is gathered:\n\nThought: I now
-      know the final answer\nFinal Answer: the final answer to the original input
-      question"}, {"role": "user", "content": "\nCurrent Task: Use the dummy tool
-      to get a result for ''test query''\n\nThis is the expect criteria for your final
-      answer: The result from the dummy tool\nyou MUST return the actual complete
-      content as the final answer, not a summary.\n\nBegin! This is VERY important
-      to you, use the tools available and give your best Final Answer, your job depends
-      on it!\n\nThought:"}], "model": "gpt-3.5-turbo", "stop": ["\nObservation:"],
-      "stream": false}'
+      Useful for when you need to get a dummy result for a query.\n\nIMPORTANT: Use
+      the following format in your response:\n\n```\nThought: you should always think
+      about what to do\nAction: the action to take, only one name of [dummy_tool],
+      just the name, exactly as it''s written.\nAction Input: the input to the action,
+      just a simple JSON object, enclosed in curly braces, using \" to wrap keys and
+      values.\nObservation: the result of the action\n```\n\nOnce all necessary information
+      is gathered, return the following format:\n\n```\nThought: I now know the final
+      answer\nFinal Answer: the final answer to the original input question\n```"},{"role":"user","content":"\nCurrent
+      Task: Use the dummy tool to get a result for ''test query''\n\nThis is the expected
+      criteria for your final answer: The result from the dummy tool\nyou MUST return
+      the actual complete content as the final answer, not a summary.\n\nBegin! This
+      is VERY important to you, use the tools available and give your best Final Answer,
+      your job depends on it!\n\nThought:"}],"model":"gpt-3.5-turbo"}'
     headers:
       accept:
       - application/json
@@ -26,13 +25,13 @@ interactions:
       connection:
      - keep-alive
       content-length:
-      - '1363'
+      - '1381'
       content-type:
       - application/json
       host:
       - api.openai.com
       user-agent:
-      - OpenAI/Python 1.52.1
+      - OpenAI/Python 1.109.1
       x-stainless-arch:
       - arm64
       x-stainless-async:
@@ -42,35 +41,33 @@ interactions:
       x-stainless-os:
       - MacOS
       x-stainless-package-version:
-      - 1.52.1
-      x-stainless-raw-response:
-      - 'true'
+      - 1.109.1
+      x-stainless-read-timeout:
+      - '600'
       x-stainless-retry-count:
       - '0'
       x-stainless-runtime:
       - CPython
       x-stainless-runtime-version:
-      - 3.12.7
+      - 3.12.10
     method: POST
    uri: https://api.openai.com/v1/chat/completions
  response:
-    content: "{\n \"id\": \"chatcmpl-AmjTkjHtNtJfKGo6wS35grXEzfoqv\",\n \"object\":
-      \"chat.completion\",\n \"created\": 1736177928,\n \"model\": \"gpt-3.5-turbo-0125\",\n
-      \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
-      \"assistant\",\n \"content\": \"I should use the dummy tool to get a
-      result for the 'test query'.\\n\\nAction: dummy_tool\\nAction Input: {\\\"query\\\":
-      \\\"test query\\\"}\",\n \"refusal\": null\n },\n \"logprobs\":
-      null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
-      271,\n \"completion_tokens\": 31,\n \"total_tokens\": 302,\n \"prompt_tokens_details\":
-      {\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
-      {\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
-      0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"system_fingerprint\":
-      null\n}\n"
+    body:
+      string: !!binary |
+        H4sIAAAAAAAAA4xTwW4TMRC95ytGvvSSVGlDWthbqYSIECAQSFRstXK8s7tuvR5jj5uGKv+O7CTd
+        FAriYtnz5j2/8YwfRgBC16IAoTrJqndmctl8ff3tJsxWd29vLu/7d1eXnz4vfq7cVft+1ohxYtDy
+        BhXvWceKemeQNdktrDxKxqR6cn72YjqdzU/mGeipRpNorePJ7Hg+4eiXNJmenM53zI60wiAK+D4C
+        AHjIa/Joa7wXBUzH+0iPIcgWRfGYBCA8mRQRMgQdWFoW4wFUZBlttr2A0FE0NcSAwB1CHft+XTGR
+        ASZokUGCxxANQ0M+pxwxBoYfEf366Li0FyoVXBww9zFYWBe5gIdS5OxS5H2NQXntUkaKfCCLYygF
+        rx2mcykC+1JsNqX9uAzo7+RW/8veHWR3nQzgkaO3WIPcIf92WtovHcW24wIWYGkFt2lJiY220oC0
+        YYW+tG/y6SKftvfudT31wytlH4fv6rGJQaa+2mjMASCtJc5l5I5e75DNYw8Ntc7TMvxGFY22OnSV
+        RxnIpn4FJicyuhkBXOdZiU/aL5yn3nHFdIv5utOXr7Z6YhjPAT2f7UAmlmaIz85Ox8/oVTWy1CYc
+        TJtQUnVYD9RhNGWsNR0Ao4Oq/3TznPa2cm3b/5EfAKXQMdaV81hr9bTiIc1j+r1/S3t85WxYpEnU
+        CivW6FMnamxkNNt/JcI6MPZVo22L3nmdP1fq5Ggz+gUAAP//AwDDsh2ZWwQAAA==
     headers:
       CF-Cache-Status:
       - DYNAMIC
       CF-RAY:
-      - 8fdccc13af387bb2-ATL
+      - 9a3a73adce2d43c2-EWR
       Connection:
       - keep-alive
       Content-Encoding:
@@ -78,15 +75,17 @@ interactions:
       Content-Type:
       - application/json
       Date:
-      - Mon, 06 Jan 2025 15:38:48 GMT
+      - Mon, 24 Nov 2025 16:58:36 GMT
       Server:
       - cloudflare
       Set-Cookie:
-      - __cf_bm=PdbRW9vzO7559czIqn0xmXQjbN8_vV_J7k1DlkB4d_Y-1736177928-1.0.1.1-7yNcyljwqHI.TVflr9ZnkS705G.K5hgPbHpxRzcO3ZMFi5lHCBPs_KB5pFE043wYzPmDIHpn6fu6jIY9mlNoLQ;
-        path=/; expires=Mon, 06-Jan-25 16:08:48 GMT; domain=.api.openai.com; HttpOnly;
+      - __cf_bm=Xa8khOM9zEqqwwmzvZrdS.nMU9nW06e0gk4Xg8ga5BI-1764003516-1.0.1.1-mR_vAWrgEyaykpsxgHq76VhaNTOdAWeNJweR1bmH1wVJgzoE0fuSPEKZMJy9Uon.1KBTV3yJVxLvQ4PjPLuE30IUdwY9Lrfbz.Rhb6UVbwY;
+        path=/; expires=Mon, 24-Nov-25 17:28:36 GMT; domain=.api.openai.com; HttpOnly;
         Secure; SameSite=None
-      - _cfuvid=lOOz0FbrrPaRb4IFEeHNcj7QghHzxI1tTV2N0jD9icA-1736177928767-0.0.1.1-604800000;
+      - _cfuvid=GP8hWglm1PiEe8AjYsdeCiIUtkA7483Hr9Ws4AZWe5U-1764003516772-0.0.1.1-604800000;
        path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
       Strict-Transport-Security:
       - max-age=31536000; includeSubDomains; preload
       Transfer-Encoding:
       - chunked
       X-Content-Type-Options:
@@ -95,14 +94,20 @@ interactions:
       - X-Request-ID
       alt-svc:
       - h3=":443"; ma=86400
+      cf-cache-status:
+      - DYNAMIC
       openai-organization:
-      - crewai-iuxna1
+      - REDACTED
       openai-processing-ms:
-      - '444'
+      - '1413'
       openai-project:
       - proj_xitITlrFeen7zjNSzML82h9x
       openai-version:
       - '2020-10-01'
       strict-transport-security:
       - max-age=31536000; includeSubDomains; preload
+      x-envoy-upstream-service-time:
+      - '1606'
+      x-openai-proxy-wasm:
+      - v0.1
       x-ratelimit-limit-requests:
       - '10000'
       x-ratelimit-limit-tokens:
@@ -110,36 +115,52 @@ interactions:
       x-ratelimit-remaining-requests:
       - '9999'
       x-ratelimit-remaining-tokens:
-      - '49999686'
+      - '49999684'
       x-ratelimit-reset-requests:
       - 6ms
       x-ratelimit-reset-tokens:
       - 0s
       x-request-id:
-      - req_5b3e93f5d4e6ab8feef83dc26b6eb623
-    http_version: HTTP/1.1
-    status_code: 200
+      - req_REDACTED
+    status:
+      code: 200
+      message: OK
 - request:
-    body: '{"messages": [{"role": "system", "content": "You are test role. test backstory\nYour
+    body: '{"messages":[{"role":"system","content":"You are test role. test backstory\nYour
       personal goal is: test goal\nYou ONLY have access to the following tools, and
      should NEVER make up tools that are not listed here:\n\nTool Name: dummy_tool\nTool
      Arguments: {''query'': {''description'': None, ''type'': ''str''}}\nTool Description:
-      Useful for when you need to get a dummy result for a query.\n\nUse the following
-      format:\n\nThought: you should always think about what to do\nAction: the action
-      to take, only one name of [dummy_tool], just the name, exactly as it''s written.\nAction
-      Input: the input to the action, just a simple python dictionary, enclosed in
-      curly braces, using \" to wrap keys and values.\nObservation: the result of
-      the action\n\nOnce all necessary information is gathered:\n\nThought: I now
-      know the final answer\nFinal Answer: the final answer to the original input
-      question"}, {"role": "user", "content": "\nCurrent Task: Use the dummy tool
-      to get a result for ''test query''\n\nThis is the expect criteria for your final
-      answer: The result from the dummy tool\nyou MUST return the actual complete
-      content as the final answer, not a summary.\n\nBegin! This is VERY important
-      to you, use the tools available and give your best Final Answer, your job depends
-      on it!\n\nThought:"}, {"role": "assistant", "content": "I should use the dummy
-      tool to get a result for the ''test query''.\n\nAction: dummy_tool\nAction Input:
-      {\"query\": \"test query\"}\nObservation: Dummy result for: test query"}], "model":
-      "gpt-3.5-turbo", "stop": ["\nObservation:"], "stream": false}'
+      Useful for when you need to get a dummy result for a query.\n\nIMPORTANT: Use
+      the following format in your response:\n\n```\nThought: you should always think
+      about what to do\nAction: the action to take, only one name of [dummy_tool],
+      just the name, exactly as it''s written.\nAction Input: the input to the action,
+      just a simple JSON object, enclosed in curly braces, using \" to wrap keys and
+      values.\nObservation: the result of the action\n```\n\nOnce all necessary information
+      is gathered, return the following format:\n\n```\nThought: I now know the final
+      answer\nFinal Answer: the final answer to the original input question\n```"},{"role":"user","content":"\nCurrent
+      Task: Use the dummy tool to get a result for ''test query''\n\nThis is the expected
+      criteria for your final answer: The result from the dummy tool\nyou MUST return
+      the actual complete content as the final answer, not a summary.\n\nBegin! This
+      is VERY important to you, use the tools available and give your best Final Answer,
+      your job depends on it!\n\nThought:"},{"role":"assistant","content":"I should
+      use the dummy_tool to get a result for the ''test query''.\nAction: dummy_tool\nAction
+      Input: {\"query\": {\"description\": None, \"type\": \"str\"}}\nObservation:
+      \nI encountered an error while trying to use the tool. This was the error: Arguments
+      validation failed: 1 validation error for Dummy_Tool\nquery\n  Input should
+      be a valid string [type=string_type, input_value={''description'': ''None'',
+      ''type'': ''str''}, input_type=dict]\n    For further information visit https://errors.pydantic.dev/2.12/v/string_type.\n
+      Tool dummy_tool accepts these inputs: Tool Name: dummy_tool\nTool Arguments:
+      {''query'': {''description'': None, ''type'': ''str''}}\nTool Description: Useful
+      for when you need to get a dummy result for a query..\nMoving on then. I MUST
+      either use a tool (use one at time) OR give my best final answer not both at
+      the same time. When responding, I must use the following format:\n\n```\nThought:
+      you should always think about what to do\nAction: the action to take, should
+      be one of [dummy_tool]\nAction Input: the input to the action, dictionary enclosed
+      in curly braces\nObservation: the result of the action\n```\nThis Thought/Action/Action
+      Input/Result can repeat N times. Once I know the final answer, I must return
+      the following format:\n\n```\nThought: I now can give a great answer\nFinal
+      Answer: Your final answer must be the great and the most complete as possible,
+      it must be outcome described\n\n```"}],"model":"gpt-3.5-turbo"}'
     headers:
       accept:
       - application/json
@@ -148,16 +169,16 @@ interactions:
       connection:
       - keep-alive
       content-length:
-      - '1574'
+      - '2841'
       content-type:
       - application/json
       cookie:
-      - __cf_bm=PdbRW9vzO7559czIqn0xmXQjbN8_vV_J7k1DlkB4d_Y-1736177928-1.0.1.1-7yNcyljwqHI.TVflr9ZnkS705G.K5hgPbHpxRzcO3ZMFi5lHCBPs_KB5pFE043wYzPmDIHpn6fu6jIY9mlNoLQ;
-        _cfuvid=lOOz0FbrrPaRb4IFEeHNcj7QghHzxI1tTV2N0jD9icA-1736177928767-0.0.1.1-604800000
+      - __cf_bm=Xa8khOM9zEqqwwmzvZrdS.nMU9nW06e0gk4Xg8ga5BI-1764003516-1.0.1.1-mR_vAWrgEyaykpsxgHq76VhaNTOdAWeNJweR1bmH1wVJgzoE0fuSPEKZMJy9Uon.1KBTV3yJVxLvQ4PjPLuE30IUdwY9Lrfbz.Rhb6UVbwY;
+        _cfuvid=GP8hWglm1PiEe8AjYsdeCiIUtkA7483Hr9Ws4AZWe5U-1764003516772-0.0.1.1-604800000
       host:
       - api.openai.com
       user-agent:
-      - OpenAI/Python 1.52.1
+      - OpenAI/Python 1.109.1
       x-stainless-arch:
       - arm64
       x-stainless-async:
@@ -167,34 +188,34 @@ interactions:
       x-stainless-os:
       - MacOS
       x-stainless-package-version:
-      - 1.52.1
-      x-stainless-raw-response:
-      - 'true'
+      - 1.109.1
+      x-stainless-read-timeout:
+      - '600'
       x-stainless-retry-count:
       - '0'
       x-stainless-runtime:
       - CPython
       x-stainless-runtime-version:
-      - 3.12.7
+      - 3.12.10
     method: POST
    uri: https://api.openai.com/v1/chat/completions
  response:
-    content: "{\n \"id\": \"chatcmpl-AmjTkjtDnt98YQ3k4y71C523EQM9p\",\n \"object\":
-      \"chat.completion\",\n \"created\": 1736177928,\n \"model\": \"gpt-3.5-turbo-0125\",\n
-      \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
-      \"assistant\",\n \"content\": \"Final Answer: Dummy result for: test
-      query\",\n \"refusal\": null\n },\n \"logprobs\": null,\n \"finish_reason\":
-      \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 315,\n \"completion_tokens\":
-      9,\n \"total_tokens\": 324,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
-      0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n
-      \ \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
-      0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"system_fingerprint\":
-      null\n}\n"
+    body:
+      string: !!binary |
+        H4sIAAAAAAAAAwAAAP//pFPbahsxEH33Vwx6yYtt7LhO0n1LWgomlFKaFko3LLJ2dletdrSRRklN
+        8L8HyZdd9wKFvgikM2cuOmeeRwBClyIDoRrJqu3M5E31+UaeL+ct335c3Ty8/frFLW5vF6G9dNfv
+        xTgy7Po7Kj6wpsq2nUHWlnawcigZY9b55cWr2WyxnF8loLUlmkirO54spssJB7e2k9n8fLlnNlYr
+        9CKDbyMAgOd0xh6pxJ8ig9n48NKi97JGkR2DAISzJr4I6b32LInFuAeVJUZKbd81NtQNZ7CCJ20M
+        KOscKgZuEDR1gaGyrpUMkkpgt4HgNdUJLkPbbgq21oCspaZpTtcqzp4NoMMbrGKyDJ5z8RDQbXKR
+        QS4YPcP+vs3pw9qje5S7HDndNQgOfTAMlbNtXxRSUe0z+BSUQu+rYMwG7JqlJixB7sMOZOsS96wv
+        dzbNKRY4Dk/2CZQkqPUjgoQ6CgeS/BO6nN5pkgau0+0/ag4lcFgFL6MFKBgzACSR5fQFSfz7PbI9
+        ym1s3Tm79r9QRaVJ+6ZwKL2lKK1n24mEbkcA98lW4cQponO27bhg+wNTuYvzva1E7+Qevbzag2xZ
+        mgHr9QE4yVeUyFIbPzCmUFI1WPbU3sUylNoOgNFg6t+7+VPu3eSa6n9J3wNKYcdYFp3DUqvTifsw
+        h3HR/xZ2/OXUsIgu1goL1uiiEiVWMpjdCgq/8YxtUWmq0XVOpz2MSo62oxcAAAD//wMA+UmELoYE
+        AAA=
     headers:
       CF-Cache-Status:
       - DYNAMIC
       CF-RAY:
-      - 8fdccc171b647bb2-ATL
+      - 9a3a73bbf9d943c2-EWR
       Connection:
       - keep-alive
       Content-Encoding:
@@ -202,9 +223,11 @@ interactions:
       Content-Type:
       - application/json
       Date:
-      - Mon, 06 Jan 2025 15:38:49 GMT
+      - Mon, 24 Nov 2025 16:58:39 GMT
       Server:
       - cloudflare
+      Strict-Transport-Security:
+      - max-age=31536000; includeSubDomains; preload
       Transfer-Encoding:
       - chunked
       X-Content-Type-Options:
@@ -213,14 +236,20 @@ interactions:
       - X-Request-ID
       alt-svc:
       - h3=":443"; ma=86400
+      cf-cache-status:
+      - DYNAMIC
       openai-organization:
-      - crewai-iuxna1
+      - REDACTED
       openai-processing-ms:
-      - '249'
+      - '1513'
       openai-project:
      - proj_xitITlrFeen7zjNSzML82h9x
       openai-version:
       - '2020-10-01'
       strict-transport-security:
       - max-age=31536000; includeSubDomains; preload
+      x-envoy-upstream-service-time:
+      - '1753'
+      x-openai-proxy-wasm:
+      - v0.1
       x-ratelimit-limit-requests:
       - '10000'
       x-ratelimit-limit-tokens:
@@ -228,103 +257,156 @@ interactions:
       x-ratelimit-remaining-requests:
       - '9999'
       x-ratelimit-remaining-tokens:
-      - '49999643'
+      - '49999334'
       x-ratelimit-reset-requests:
       - 6ms
       x-ratelimit-reset-tokens:
       - 0s
       x-request-id:
-      - req_cdc7b25a3877bb9a7cb7c6d2645ff447
|
||||
http_version: HTTP/1.1
|
||||
status_code: 200
|
||||
- req_REDACTED
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"trace_id": "1581aff1-2567-43f4-a1f2-a2816533eb7d", "execution_type":
|
||||
"crew", "user_identifier": null, "execution_context": {"crew_fingerprint": null,
|
||||
"crew_name": "Unknown Crew", "flow_name": null, "crewai_version": "0.201.1",
|
||||
"privacy_level": "standard"}, "execution_metadata": {"expected_duration_estimate":
|
||||
300, "agent_count": 0, "task_count": 0, "flow_method_count": 0, "execution_started_at":
|
||||
"2025-10-08T18:11:28.008595+00:00"}}'
|
||||
body: '{"messages":[{"role":"system","content":"You are test role. test backstory\nYour
|
||||
personal goal is: test goal\nYou ONLY have access to the following tools, and
|
||||
should NEVER make up tools that are not listed here:\n\nTool Name: dummy_tool\nTool
|
||||
Arguments: {''query'': {''description'': None, ''type'': ''str''}}\nTool Description:
|
||||
Useful for when you need to get a dummy result for a query.\n\nIMPORTANT: Use
|
||||
the following format in your response:\n\n```\nThought: you should always think
|
||||
about what to do\nAction: the action to take, only one name of [dummy_tool],
|
||||
just the name, exactly as it''s written.\nAction Input: the input to the action,
|
||||
just a simple JSON object, enclosed in curly braces, using \" to wrap keys and
|
||||
values.\nObservation: the result of the action\n```\n\nOnce all necessary information
|
||||
is gathered, return the following format:\n\n```\nThought: I now know the final
|
||||
answer\nFinal Answer: the final answer to the original input question\n```"},{"role":"user","content":"\nCurrent
|
||||
Task: Use the dummy tool to get a result for ''test query''\n\nThis is the expected
|
||||
criteria for your final answer: The result from the dummy tool\nyou MUST return
|
||||
the actual complete content as the final answer, not a summary.\n\nBegin! This
|
||||
is VERY important to you, use the tools available and give your best Final Answer,
|
||||
your job depends on it!\n\nThought:"},{"role":"assistant","content":"I should
|
||||
use the dummy_tool to get a result for the ''test query''.\nAction: dummy_tool\nAction
|
||||
Input: {\"query\": {\"description\": None, \"type\": \"str\"}}\nObservation:
|
||||
\nI encountered an error while trying to use the tool. This was the error: Arguments
|
||||
validation failed: 1 validation error for Dummy_Tool\nquery\n Input should
|
||||
be a valid string [type=string_type, input_value={''description'': ''None'',
|
||||
''type'': ''str''}, input_type=dict]\n For further information visit https://errors.pydantic.dev/2.12/v/string_type.\n
|
||||
Tool dummy_tool accepts these inputs: Tool Name: dummy_tool\nTool Arguments:
|
||||
{''query'': {''description'': None, ''type'': ''str''}}\nTool Description: Useful
|
||||
for when you need to get a dummy result for a query..\nMoving on then. I MUST
|
||||
either use a tool (use one at time) OR give my best final answer not both at
|
||||
the same time. When responding, I must use the following format:\n\n```\nThought:
|
||||
you should always think about what to do\nAction: the action to take, should
|
||||
be one of [dummy_tool]\nAction Input: the input to the action, dictionary enclosed
|
||||
in curly braces\nObservation: the result of the action\n```\nThis Thought/Action/Action
|
||||
Input/Result can repeat N times. Once I know the final answer, I must return
|
||||
the following format:\n\n```\nThought: I now can give a great answer\nFinal
|
||||
Answer: Your final answer must be the great and the most complete as possible,
|
||||
it must be outcome described\n\n```"},{"role":"assistant","content":"Thought:
|
||||
I will correct the input format and try using the dummy_tool again.\nAction:
|
||||
dummy_tool\nAction Input: {\"query\": \"test query\"}\nObservation: Dummy result
|
||||
for: test query"}],"model":"gpt-3.5-turbo"}'
|
||||
headers:
|
||||
Accept:
|
||||
- '*/*'
|
||||
Accept-Encoding:
|
||||
- gzip, deflate, zstd
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '436'
|
||||
Content-Type:
|
||||
accept:
|
||||
- application/json
|
||||
User-Agent:
|
||||
- CrewAI-CLI/0.201.1
|
||||
X-Crewai-Organization-Id:
|
||||
- d3a3d10c-35db-423f-a7a4-c026030ba64d
|
||||
X-Crewai-Version:
|
||||
- 0.201.1
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '3057'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- __cf_bm=Xa8khOM9zEqqwwmzvZrdS.nMU9nW06e0gk4Xg8ga5BI-1764003516-1.0.1.1-mR_vAWrgEyaykpsxgHq76VhaNTOdAWeNJweR1bmH1wVJgzoE0fuSPEKZMJy9Uon.1KBTV3yJVxLvQ4PjPLuE30IUdwY9Lrfbz.Rhb6UVbwY;
|
||||
_cfuvid=GP8hWglm1PiEe8AjYsdeCiIUtkA7483Hr9Ws4AZWe5U-1764003516772-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.10
|
||||
method: POST
|
||||
uri: http://localhost:3000/crewai_plus/api/v1/tracing/batches
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: '{"id":"30844ebe-8ac6-4f67-939a-7a072d792654","trace_id":"1581aff1-2567-43f4-a1f2-a2816533eb7d","execution_type":"crew","crew_name":"Unknown
|
||||
Crew","flow_name":null,"status":"running","duration_ms":null,"crewai_version":"0.201.1","privacy_level":"standard","total_events":0,"execution_context":{"crew_fingerprint":null,"crew_name":"Unknown
|
||||
Crew","flow_name":null,"crewai_version":"0.201.1","privacy_level":"standard"},"created_at":"2025-10-08T18:11:28.353Z","updated_at":"2025-10-08T18:11:28.353Z"}'
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFLBbhMxEL3vV4x8TqqkTULZWwFFAq4gpEK18npnd028HmOPW6Iq/47s
|
||||
pNktFKkXS/abN37vzTwWAEI3ogSheslqcGb+vv36rt7e0uqzbna0ut18uv8mtxSDrddKzBKD6p+o
|
||||
+Il1oWhwBlmTPcLKo2RMXZdvNqvF4mq9fJuBgRo0idY5nl9drOccfU3zxfJyfWL2pBUGUcL3AgDg
|
||||
MZ9Jo23wtyhhMXt6GTAE2aEoz0UAwpNJL0KGoANLy2I2gooso82yv/QUu55L+AiWHmCXDu4RWm2l
|
||||
AWnDA/ofdptvN/lWwoc4DHvwGKJhaMmXwBgYfkX0++k3HtsYZLJpozETQFpLLFNM2eDdCTmcLRnq
|
||||
nKc6/EUVrbY69JVHGcgm+YHJiYweCoC7HF18loZwngbHFdMO83ebzerYT4zTGtHl9QlkYmkmrOvL
|
||||
2Qv9qgZZahMm4QslVY/NSB0nJWOjaQIUE9f/qnmp99G5tt1r2o+AUugYm8p5bLR67ngs85iW+X9l
|
||||
55SzYBHQ32uFFWv0aRINtjKa45qJsA+MQ9Vq26F3XuddS5MsDsUfAAAA//8DANWDXp9qAwAA
|
||||
headers:
|
||||
Content-Length:
|
||||
- '496'
|
||||
cache-control:
|
||||
- no-store
|
||||
content-security-policy:
|
||||
- 'default-src ''self'' *.crewai.com crewai.com; script-src ''self'' ''unsafe-inline''
|
||||
*.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts https://www.gstatic.com
|
||||
https://run.pstmn.io https://apis.google.com https://apis.google.com/js/api.js
|
||||
https://accounts.google.com https://accounts.google.com/gsi/client https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.min.css.map
|
||||
https://*.google.com https://docs.google.com https://slides.google.com https://js.hs-scripts.com
|
||||
https://js.sentry-cdn.com https://browser.sentry-cdn.com https://www.googletagmanager.com
|
||||
https://js-na1.hs-scripts.com https://share.descript.com/; style-src ''self''
|
||||
''unsafe-inline'' *.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts;
|
||||
img-src ''self'' data: *.crewai.com crewai.com https://zeus.tools.crewai.com
|
||||
https://dashboard.tools.crewai.com https://cdn.jsdelivr.net; font-src ''self''
|
||||
data: *.crewai.com crewai.com; connect-src ''self'' *.crewai.com crewai.com
|
||||
https://zeus.tools.crewai.com https://connect.useparagon.com/ https://zeus.useparagon.com/*
|
||||
https://*.useparagon.com/* https://run.pstmn.io https://connect.tools.crewai.com/
|
||||
https://*.sentry.io https://www.google-analytics.com ws://localhost:3036 wss://localhost:3036;
|
||||
frame-src ''self'' *.crewai.com crewai.com https://connect.useparagon.com/
|
||||
https://zeus.tools.crewai.com https://zeus.useparagon.com/* https://connect.tools.crewai.com/
|
||||
https://docs.google.com https://drive.google.com https://slides.google.com
|
||||
https://accounts.google.com https://*.google.com https://www.youtube.com https://share.descript.com'
|
||||
content-type:
|
||||
- application/json; charset=utf-8
|
||||
etag:
|
||||
- W/"a548892c6a8a52833595a42b35b10009"
|
||||
expires:
|
||||
- '0'
|
||||
permissions-policy:
|
||||
- camera=(), microphone=(self), geolocation=()
|
||||
pragma:
|
||||
- no-cache
|
||||
referrer-policy:
|
||||
- strict-origin-when-cross-origin
|
||||
server-timing:
|
||||
- cache_read.active_support;dur=0.05, cache_fetch_hit.active_support;dur=0.00,
|
||||
cache_read_multi.active_support;dur=0.12, start_processing.action_controller;dur=0.00,
|
||||
sql.active_record;dur=30.46, instantiation.active_record;dur=0.38, feature_operation.flipper;dur=0.03,
|
||||
start_transaction.active_record;dur=0.01, transaction.active_record;dur=16.78,
|
||||
process_action.action_controller;dur=309.67
|
||||
vary:
|
||||
- Accept
|
||||
x-content-type-options:
|
||||
CF-RAY:
|
||||
- 9a3a73cd4ff343c2-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 24 Nov 2025 16:58:40 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
x-frame-options:
|
||||
- SAMEORIGIN
|
||||
x-permitted-cross-domain-policies:
|
||||
- none
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- REDACTED
|
||||
openai-processing-ms:
|
||||
- '401'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '421'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '50000000'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '49999290'
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- 7ec132be-e871-4b0a-93f7-81f8d7c0ccae
|
||||
x-runtime:
|
||||
- '0.358533'
|
||||
x-xss-protection:
|
||||
- 1; mode=block
|
||||
- req_REDACTED
|
||||
status:
|
||||
code: 201
|
||||
message: Created
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,69 @@
interactions:
- request:
body: '{"contents":[{"role":"user","parts":[{"text":"What is the capital of France?"}]}],"generationConfig":{"stop_sequences":[]}}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '123'
content-type:
- application/json
host:
- generativelanguage.googleapis.com
user-agent:
- litellm/1.78.5
method: POST
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-3-pro-preview:generateContent
response:
body:
string: !!binary |
H4sIAAAAAAAC/21UW4+iSBh9719heGxmBgFvbDIPgKAgNwUV3OxDCSWU3KFApdP/fWl77XF2l6RI
5ft6TuXUqfLtZTAgfJAFKAAY1sQfgz/7ymDwdv9/9PIMwwz3jUepLxagwr+wn9/b076HYHj9OEQ4
ERz4oEAYJIP8NJArkPlwgOrB66sFKlS/vv4gvv3raJQ3YYRtFGYANxX8oJHqtSTWksQHS4xSjuZd
u0PVWCHliDf3jrGn2GZmalFzwZKYplwaH6tGnu943ir0oHRRtpYyOD34m/UmDlb2RksWqlEMmfWZ
iVhKX7aGOYa2GE7sJiIx7wnTUotZgVp3SeVW1uUaKseQNkcbVFOJzLlMbPjctg0qdeGLgZsqtrhO
vWa/dKv63GQTrj3Iw6oqTrd2NzKMFu68RRt7nKAuMLRBh6Do1VUd7gMAbfNgSp6lSzwVkc25LscqVqe8jFeeNQbjmyiOtm2HwWF+tpCS7Fe2EQcLzsGSYeoBY1KMOSY3qjeRHEk3EMWfTxUnWYJUid7xMgkn5awV2FrrUhlfRSfMFf5c4cTnRtFt2/GlKnQjuubcJGjSTthODThWp6JJH2cy304DRsqmEr/RPWOd0d6cm5BR0mxns7lI04WwU3mQ2SRcrsOp0mRWx2uWhiWBx7rjH/eb+CwggW8n/nEhLZxkexyW0WVsrxhtr0JVHyqQMhgl15mTfdiYCEUNeV4lOTpRe1fByM5WDmiMm9C1hsuPxZPJqoXI97wQ0rjmdqFMbnFkgkvJLk6w6vRI0mxN3plcfrxwLKMllq0K08AUI8ssXIUap3RCsRfrspkyVrGNA38WwRB1h3ZzXZV64S09nYwajqRirmR3QYo0aQ4SacnZwWg9p0In6EJlGnsNqxnn02w7apWsog5ud9NC41Tv/NWiKygSQbkNjMwZ+tlZKBejMr3F6/wU5xRdsmFdVjggaSCllKanFsT4Us0p/zrhruCQl9cElE25048yM2xN0l/5Z8FjM0jxP38ST+Py/rX/69cUEVWe3KcmzQOYPODvDwBxQhmqow0EdZ59wGzHtL6GkEBZAK99efjyELhTE00NQqhDDPqkAF95QBRVnhbYyWOYiXlzT4rZJ9dTrvxvG+d9KPzWoUf0t/+w1vNeEyXPefMURf0TQYLw7Z41kus8hUkv8Kz6sOHlya1H1vx+QZoZv/zj16eFO1jV6NOrEKa9e9/Z7/0N+wVbBC93TaKCdZFnNVSCD9zO8/rJZvll+V07N7pVW608W3vEy/vL3+jWjNGlKwAA
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Content-Encoding:
- gzip
Content-Type:
- application/json; charset=UTF-8
Date:
- Wed, 19 Nov 2025 08:56:53 GMT
Server:
- scaffolding on HTTPServer2
Server-Timing:
- gfet4t7; dur=2508
Transfer-Encoding:
- chunked
Vary:
- Origin
- X-Origin
- Referer
X-Content-Type-Options:
- nosniff
X-Frame-Options:
- SAMEORIGIN
X-XSS-Protection:
- '0'
status:
code: 200
message: OK
version: 1
lib/crewai/tests/cassettes/test_openai_response_format_none.yaml | 112 (new file)
@@ -0,0 +1,112 @@
interactions:
- request:
body: '{"messages":[{"role":"user","content":"Say hello in one word"}],"model":"gpt-4o"}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '81'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.109.1
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.109.1
x-stainless-read-timeout:
- '600'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.10
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jJJNT9wwEIbv+RXunDdVNtoPuteiCoEQJ7SiFYqMPcm6OB7LnvAhtP8d
OWE3oYDUiw9+5h2/73heMiHAaNgIUDvJqvU2/1nf1K44vVzePG6vr5cPV2V5/nt7eqG36lcLs6Sg
u7+o+KD6rqj1FtmQG7AKKBlT1/l6tSjKYr1a96AljTbJGs/5gvKyKBd5cZIXqzfhjozCCBvxJxNC
iJf+TBadxifYiGJ2uGkxRtkgbI5FQkAgm25AxmgiS8cwG6Eix+h612doLX2bwoB1F2Xy5jprJ0A6
RyxTtt7W7RvZH41Yanygu/iPFGrjTNxVAWUklx6NTB56us+EuO0Dd+8ygA/Ueq6Y7rF/bl4O7WCc
8AgPjImlnWgWs0+aVRpZGhsn8wIl1Q71qByHKzttaAKySeSPXj7rPcQ2rvmf9iNQCj2jrnxAbdT7
vGNZwLR+X5UdR9wbhojhwSis2GBI36Cxlp0dNgPic2Rsq9q4BoMPZliP2lfqxwkWSyXna8j22SsA
AAD//wMAmJrFFCcDAAA=
headers:
CF-RAY:
- 9a3c18dff8580f53-EWR
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Mon, 24 Nov 2025 21:46:08 GMT
Server:
- cloudflare
Set-Cookie:
- FILTERED
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- FILTERED
openai-processing-ms:
- '1096'
openai-project:
- FILTERED
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '1138'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-requests:
- '10000'
x-ratelimit-limit-requests:
- '10000'
x-ratelimit-limit-tokens:
- '30000000'
x-ratelimit-remaining-project-requests:
- '9999'
x-ratelimit-remaining-requests:
- '9999'
x-ratelimit-remaining-tokens:
- '29999992'
x-ratelimit-reset-project-requests:
- 6ms
x-ratelimit-reset-requests:
- 6ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_670507131d6c455caf0e8cbc30a1a792
status:
code: 200
message: OK
version: 1
@@ -0,0 +1,113 @@
interactions:
- request:
body: '{"messages":[{"role":"user","content":"Return a JSON object with a ''status''
field set to ''success''"}],"model":"gpt-4o","response_format":{"type":"json_object"}}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '160'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.109.1
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.109.1
x-stainless-read-timeout:
- '600'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.10
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAA4xSwW6cMBC98xXWnJeKkC274Zr01vbUVopKhLxmACfGdj1D1Wi1/14ZNgtpU6kX
hG/ee7z3mGMiBOgGSgGql6wGb9Lb9r4191/46eaD+/j509f9Lvs2PP+47u/usIdNZLjDIyp+Yb1T
bvAGWTs7wyqgZIyqV7tim+XZrng/AYNr0ERa5zndujTP8m2a7dOsOBN7pxUSlOJ7IoQQx+kZLdoG
f0Epss3LZEAi2SGUlyUhIDgTJyCJNLG0DJsFVM4y2sn1sbJCVEAseaQKyvg+KoVEFVT2tGYFbEeS
0bQdjVkB0lrHMoae/D6ckdPFoXGdD+5Af1Ch1VZTXweU5Gx0Q+w8TOgpEeJhamJ8FQ58cIPnmt0T
Tp/L81kOluoX8OaMsWNplvH11eYNsbpBltrQqkhQUvXYLMyldTk22q2AZBX5by9vac+xte3+R34B
lELP2NQ+YKPV67zLWsB4l/9au1Q8GQbC8FMrrFljiL+hwVaOZj4ZoGdiHOpW2w6DD3q+m9bXO8TD
tmizYg/JKfkNAAD//wMA0CE0wkADAAA=
headers:
CF-RAY:
- 9a3c18d7de3c80dc-EWR
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Mon, 24 Nov 2025 21:46:06 GMT
Server:
- cloudflare
Set-Cookie:
- FILTERED
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- FILTERED
openai-processing-ms:
- '424'
openai-project:
- FILTERED
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '443'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-requests:
- '10000'
x-ratelimit-limit-requests:
- '10000'
x-ratelimit-limit-tokens:
- '30000000'
x-ratelimit-remaining-project-requests:
- '9999'
x-ratelimit-remaining-requests:
- '9999'
x-ratelimit-remaining-tokens:
- '29999983'
x-ratelimit-reset-project-requests:
- 6ms
x-ratelimit-reset-requests:
- 6ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_71bc4c9f29f843d6b3788b119850dfde
status:
code: 200
message: OK
version: 1
@@ -0,0 +1,116 @@
interactions:
- request:
body: '{"messages":[{"role":"user","content":"What is the capital of France? Be
concise."}],"model":"gpt-4o","response_format":{"type":"json_schema","json_schema":{"name":"AnswerResponse","strict":true,"schema":{"description":"Response
model with structured fields.","properties":{"answer":{"description":"The answer
to the question","title":"Answer","type":"string"},"confidence":{"description":"Confidence
score between 0 and 1","title":"Confidence","type":"number"}},"required":["answer","confidence"],"title":"AnswerResponse","type":"object","additionalProperties":false}}}}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '571'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.109.1
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.109.1
x-stainless-read-timeout:
- '600'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.10
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFLLbtswELzrK4g9SwFtyA/pmKCH9pRbUVSBwJArmTVFElyqbWr43wvK
jqW0KdALDzs7w5ndPWWMgVZQM5AHEeXgTfHQfemO1f0H8cK73fDp82b/iz6uH4+WnC0hTwz3/A1l
fGXdSTd4g1E7e4FlQBExqa5225Kv+W5bTsDgFJpE630sSles+bos+L7g2yvx4LREgpp9zRhj7DS9
yaJV+BNqxvPXyoBEokeob02MQXAmVUAQaYrCRshnUDob0U6uTw0ISz8wNFA38CiCpgbyJrV0WqGV
2EDN76rqvBQI2I0kkn87GrMAhLUuipR/sv50Rc43s8b1Prhn+oMKnbaaDm1AQc4mYxSdhwk9Z4w9
TUMZ3+QEH9zgYxvdEafvqvIiB/MWZnC1uoLRRWEWdb7J35FrFUahDS2mClLIA6qZOq9AjEq7BZAt
Qv/t5j3tS3Bt+/+RnwEp0UdUrQ+otHybeG4LmI70X223IU+GgTB81xLbqDGkRSjsxGgu9wP0QhGH
ttO2x+CDvhxR51tZ7ZFvpFjtIDtnvwEAAP//AwAvoKedTQMAAA==
headers:
CF-RAY:
- 9a3c18cf7fe04253-EWR
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Mon, 24 Nov 2025 21:46:05 GMT
Server:
- cloudflare
Set-Cookie:
- FILTERED
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- FILTERED
openai-processing-ms:
- '448'
openai-project:
- FILTERED
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '465'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-requests:
- '10000'
x-ratelimit-limit-requests:
- '10000'
x-ratelimit-limit-tokens:
- '30000000'
x-ratelimit-remaining-project-requests:
- '9999'
x-ratelimit-remaining-requests:
- '9999'
x-ratelimit-remaining-tokens:
- '29999987'
x-ratelimit-reset-project-requests:
- 6ms
x-ratelimit-reset-requests:
- 6ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_765510cb1e614ed6a83e665bf7c5a07b
status:
code: 200
message: OK
version: 1
lib/crewai/tests/cli/authentication/providers/test_entra_id.py | 141 (new file)
@@ -0,0 +1,141 @@
import pytest

from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.entra_id import EntraIdProvider


class TestEntraIdProvider:
    @pytest.fixture(autouse=True)
    def setup_method(self):
        self.valid_settings = Oauth2Settings(
            provider="entra_id",
            domain="tenant-id-abcdef123456",
            client_id="test-client-id",
            audience="test-audience",
            extra={
                "scope": "openid profile email api://crewai-cli-dev/read"
            }
        )
        self.provider = EntraIdProvider(self.valid_settings)

    def test_initialization_with_valid_settings(self):
        provider = EntraIdProvider(self.valid_settings)
        assert provider.settings == self.valid_settings
        assert provider.settings.provider == "entra_id"
        assert provider.settings.domain == "tenant-id-abcdef123456"
        assert provider.settings.client_id == "test-client-id"
        assert provider.settings.audience == "test-audience"

    def test_get_authorize_url(self):
        expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
        assert self.provider.get_authorize_url() == expected_url

    def test_get_authorize_url_with_different_domain(self):
        # For EntraID, the domain is the tenant ID.
        settings = Oauth2Settings(
            provider="entra_id",
            domain="my-company.entra.id",
            client_id="test-client",
            audience="test-audience",
        )
        provider = EntraIdProvider(settings)
        expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
        assert provider.get_authorize_url() == expected_url

    def test_get_token_url(self):
        expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
        assert self.provider.get_token_url() == expected_url

    def test_get_token_url_with_different_domain(self):
        # For EntraID, the domain is the tenant ID.
        settings = Oauth2Settings(
            provider="entra_id",
            domain="another-domain.entra.id",
            client_id="test-client",
            audience="test-audience",
        )
        provider = EntraIdProvider(settings)
        expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
        assert provider.get_token_url() == expected_url

    def test_get_jwks_url(self):
        expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
        assert self.provider.get_jwks_url() == expected_url

    def test_get_jwks_url_with_different_domain(self):
        # For EntraID, the domain is the tenant ID.
        settings = Oauth2Settings(
            provider="entra_id",
            domain="dev.entra.id",
            client_id="test-client",
            audience="test-audience",
        )
        provider = EntraIdProvider(settings)
        expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
        assert provider.get_jwks_url() == expected_url

    def test_get_issuer(self):
        expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
        assert self.provider.get_issuer() == expected_issuer

    def test_get_issuer_with_different_domain(self):
        # For EntraID, the domain is the tenant ID.
        settings = Oauth2Settings(
            provider="entra_id",
            domain="other-tenant-id-xpto",
            client_id="test-client",
            audience="test-audience",
        )
        provider = EntraIdProvider(settings)
        expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
        assert provider.get_issuer() == expected_issuer

    def test_get_audience(self):
        assert self.provider.get_audience() == "test-audience"

    def test_get_audience_assertion_error_when_none(self):
        settings = Oauth2Settings(
            provider="entra_id",
            domain="test-tenant-id",
            client_id="test-client-id",
            audience=None,
        )
        provider = EntraIdProvider(settings)

        with pytest.raises(ValueError, match="Audience is required"):
            provider.get_audience()

    def test_get_client_id(self):
        assert self.provider.get_client_id() == "test-client-id"

    def test_get_required_fields(self):
        assert set(self.provider.get_required_fields()) == set(["scope"])

    def test_get_oauth_scopes(self):
        settings = Oauth2Settings(
            provider="entra_id",
            domain="tenant-id-abcdef123456",
            client_id="test-client-id",
            audience="test-audience",
            extra={
                "scope": "api://crewai-cli-dev/read"
            }
        )
        provider = EntraIdProvider(settings)
        assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"]

    def test_get_oauth_scopes_with_multiple_custom_scopes(self):
        settings = Oauth2Settings(
            provider="entra_id",
            domain="tenant-id-abcdef123456",
            client_id="test-client-id",
            audience="test-audience",
            extra={
                "scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
            }
        )
        provider = EntraIdProvider(settings)
        assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"]

    def test_base_url(self):
        assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456"
@@ -15,6 +15,8 @@ class TestAuthenticationCommand:
    def setup_method(self):
        self.auth_command = AuthenticationCommand()

    # TODO: these expectations are reading from the actual settings, we should mock them.
    # E.g. if you change the client_id locally, this test will fail.
    @pytest.mark.parametrize(
        "user_provider,expected_urls",
        [
@@ -53,7 +55,7 @@ class TestAuthenticationCommand:
        self.auth_command.login()

        mock_console_print.assert_called_once_with(
            "Signing in to CrewAI AMP...\n", style="bold blue"
            "Signing in to CrewAI AOP...\n", style="bold blue"
        )
        mock_get_device.assert_called_once()
        mock_display.assert_called_once_with(
@@ -181,7 +183,7 @@ class TestAuthenticationCommand:
            ),
            call("Success!\n", style="bold green"),
            call(
                "You are authenticated to the tool repository as [bold cyan]'Test Org'[/bold cyan] (test-uuid-123)",
                "You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
                style="green",
            ),
        ]
@@ -234,6 +236,7 @@ class TestAuthenticationCommand:
            "https://example.com/device"
        )
        self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
        self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]

        result = self.auth_command._get_device_code()

@@ -241,7 +244,7 @@ class TestAuthenticationCommand:
            url="https://example.com/device",
            data={
                "client_id": "test_client",
                "scope": "openid",
                "scope": "openid profile email",
                "audience": "test_audience",
            },
            timeout=20,
@@ -298,7 +301,7 @@ class TestAuthenticationCommand:
        expected_calls = [
            call("\nWaiting for authentication... ", style="bold blue", end=""),
            call("Success!", style="bold green"),
            call("\n[bold green]Welcome to CrewAI AMP![/bold green]\n"),
            call("\n[bold green]Welcome to CrewAI AOP![/bold green]\n"),
        ]
        mock_console_print.assert_has_calls(expected_calls)


@@ -72,7 +72,8 @@ class TestSettings(unittest.TestCase):
    @patch("crewai.cli.config.TokenManager")
    def test_reset_settings(self, mock_token_manager):
        user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
        cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
        cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}
        cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"}

        settings = Settings(
            config_path=self.config_path, **user_settings, **cli_settings

@@ -128,8 +128,6 @@ class TestAgentEvaluator:

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_eval_specific_agents_from_crew(self, mock_crew):
        from crewai.events.types.task_events import TaskCompletedEvent

        agent = Agent(
            role="Test Agent Eval",
            goal="Complete test tasks successfully",
@@ -145,7 +143,7 @@ class TestAgentEvaluator:

        events = {}
        results_condition = threading.Condition()
        results_ready = False
        completed_event_received = False

        agent_evaluator = AgentEvaluator(
            agents=[agent], evaluators=[GoalAlignmentEvaluator()]
@@ -158,29 +156,23 @@ class TestAgentEvaluator:

        @crewai_event_bus.on(AgentEvaluationCompletedEvent)
        async def capture_completed(source, event):
            nonlocal completed_event_received
            if event.agent_id == str(agent.id):
                events["completed"] = event
                with results_condition:
                    completed_event_received = True
                    results_condition.notify()

        @crewai_event_bus.on(AgentEvaluationFailedEvent)
        def capture_failed(source, event):
            events["failed"] = event

        @crewai_event_bus.on(TaskCompletedEvent)
        async def on_task_completed(source, event):
            nonlocal results_ready
            if event.task and event.task.id == task.id:
                while not agent_evaluator.get_evaluation_results().get(agent.role):
                    pass
                with results_condition:
                    results_ready = True
                    results_condition.notify()

        mock_crew.kickoff()

        with results_condition:
            assert results_condition.wait_for(
                lambda: results_ready, timeout=5
            ), "Timeout waiting for evaluation results"
                lambda: completed_event_received, timeout=5
            ), "Timeout waiting for evaluation completed event"

        assert events.keys() == {"started", "completed"}
        assert events["started"].agent_id == str(agent.id)

@@ -381,6 +381,7 @@ def test_azure_raises_error_when_endpoint_missing():
    with pytest.raises(ValueError, match="Azure endpoint is required"):
        AzureCompletion(model="gpt-4", api_key="test-key")


def test_azure_raises_error_when_api_key_missing():
    """Test that AzureCompletion raises ValueError when API key is missing"""
    from crewai.llms.providers.azure.completion import AzureCompletion
@@ -389,6 +390,8 @@ def test_azure_raises_error_when_api_key_missing():
    with patch.dict(os.environ, {}, clear=True):
        with pytest.raises(ValueError, match="Azure API key is required"):
            AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")


def test_azure_endpoint_configuration():
    """
    Test that Azure endpoint configuration works with multiple environment variable names
@@ -1086,3 +1089,27 @@ def test_azure_mistral_and_other_models():
    )
    assert "model" in params
    assert params["model"] == model_name


def test_azure_completion_params_preparation_with_drop_params():
    """
    Test that completion parameters are properly prepared with the drop parameters attribute respected
    """
    with patch.dict(os.environ, {
        "AZURE_API_KEY": "test-key",
        "AZURE_ENDPOINT": "https://models.inference.ai.azure.com"
    }):
        llm = LLM(
            model="azure/o4-mini",
            drop_params=True,
            additional_drop_params=["stop"],
            max_tokens=1000
        )

        from crewai.llms.providers.azure.completion import AzureCompletion
        assert isinstance(llm, AzureCompletion)

        messages = [{"role": "user", "content": "Hello"}]
        params = llm._prepare_completion_params(messages)

        assert params.get('stop') is None
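        # Reading of the assertion above (an inference from this test, not from
        # documented AzureCompletion behavior): with drop_params=True and
        # additional_drop_params=["stop"], the prepared request payload should
        # omit the "stop" key entirely, which is why params.get("stop") is None.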
@@ -455,13 +455,11 @@ def test_gemini_model_capabilities():
    llm_2_0 = LLM(model="google/gemini-2.0-flash-001")
    from crewai.llms.providers.gemini.completion import GeminiCompletion
    assert isinstance(llm_2_0, GeminiCompletion)
    assert llm_2_0.is_gemini_2 == True
    assert llm_2_0.supports_tools == True

    # Test Gemini 1.5 model
    llm_1_5 = LLM(model="google/gemini-1.5-pro")
    assert isinstance(llm_1_5, GeminiCompletion)
    assert llm_1_5.is_gemini_1_5 == True
    assert llm_1_5.supports_tools == True


@@ -528,3 +528,50 @@ def test_openai_streaming_with_response_model():

    assert "input" not in call_kwargs
    assert "text_format" not in call_kwargs


@pytest.mark.vcr(filter_headers=["authorization"])
def test_openai_response_format_with_pydantic_model():
    """
    Test that response_format with a Pydantic BaseModel returns structured output.
    """
    from pydantic import BaseModel, Field

    class AnswerResponse(BaseModel):
        """Response model with structured fields."""

        answer: str = Field(description="The answer to the question")
        confidence: float = Field(description="Confidence score between 0 and 1")

    llm = LLM(model="gpt-4o", response_format=AnswerResponse)
    result = llm.call("What is the capital of France? Be concise.")

    assert isinstance(result, AnswerResponse)
    assert result.answer is not None
    assert 0 <= result.confidence <= 1


@pytest.mark.vcr(filter_headers=["authorization"])
def test_openai_response_format_with_dict():
    """
    Test that response_format with a dict returns JSON output.
    """
    import json

    llm = LLM(model="gpt-4o", response_format={"type": "json_object"})
    result = llm.call("Return a JSON object with a 'status' field set to 'success'")

    parsed = json.loads(result)
    assert "status" in parsed


@pytest.mark.vcr(filter_headers=["authorization"])
def test_openai_response_format_none():
    """
    Test that when response_format is None, the API returns plain text.
    """
    llm = LLM(model="gpt-4o", response_format=None)
    result = llm.call("Say hello in one word")

    assert isinstance(result, str)
    assert len(result) > 0

lib/crewai/tests/mcp/test_sse_transport.py | 22 (new file)
@@ -0,0 +1,22 @@
"""Tests for SSE transport."""

import pytest

from crewai.mcp.transports.sse import SSETransport


@pytest.mark.asyncio
async def test_sse_transport_connect_does_not_pass_invalid_args():
    """Test that SSETransport.connect() doesn't pass invalid args to sse_client.

    The sse_client function does not accept a terminate_on_close parameter.
    """
    transport = SSETransport(
        url="http://localhost:9999/sse",
        headers={"Authorization": "Bearer test"},
    )

    with pytest.raises(ConnectionError) as exc_info:
        await transport.connect()

    assert "unexpected keyword argument" not in str(exc_info.value)
lib/crewai/tests/rag/embeddings/test_backward_compatibility.py | 364 (new file)
@@ -0,0 +1,364 @@
"""Tests for backward compatibility of embedding provider configurations."""

from crewai.rag.embeddings.factory import build_embedder, PROVIDER_PATHS
from crewai.rag.embeddings.providers.openai.openai_provider import OpenAIProvider
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
from crewai.rag.embeddings.providers.google.generative_ai import GenerativeAiProvider
from crewai.rag.embeddings.providers.google.vertex import VertexAIProvider
from crewai.rag.embeddings.providers.microsoft.azure import AzureProvider
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
from crewai.rag.embeddings.providers.ollama.ollama_provider import OllamaProvider
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import Text2VecProvider
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
    SentenceTransformerProvider,
)
from crewai.rag.embeddings.providers.instructor.instructor_provider import InstructorProvider
from crewai.rag.embeddings.providers.openclip.openclip_provider import OpenCLIPProvider


class TestGoogleProviderAlias:
    """Test that 'google' provider name alias works for backward compatibility."""

    def test_google_alias_in_provider_paths(self):
        """Verify 'google' is registered as an alias for google-generativeai."""
        assert "google" in PROVIDER_PATHS
        assert "google-generativeai" in PROVIDER_PATHS
        assert PROVIDER_PATHS["google"] == PROVIDER_PATHS["google-generativeai"]
|
||||
|
||||
|
||||
class TestModelKeyBackwardCompatibility:
|
||||
"""Test that 'model' config key works as alias for 'model_name'."""
|
||||
|
||||
def test_openai_provider_accepts_model_key(self):
|
||||
"""Test OpenAI provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = OpenAIProvider(
|
||||
api_key="test-key",
|
||||
model="text-embedding-3-small",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-3-small"
|
||||
|
||||
def test_openai_provider_model_name_takes_precedence(self):
|
||||
"""Test that model_name takes precedence when both are provided."""
|
||||
provider = OpenAIProvider(
|
||||
api_key="test-key",
|
||||
model_name="text-embedding-3-large",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-3-large"
|
||||
|
||||
def test_cohere_provider_accepts_model_key(self):
|
||||
"""Test Cohere provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = CohereProvider(
|
||||
api_key="test-key",
|
||||
model="embed-english-v3.0",
|
||||
)
|
||||
assert provider.model_name == "embed-english-v3.0"
|
||||
|
||||
def test_google_generativeai_provider_accepts_model_key(self):
|
||||
"""Test Google Generative AI provider accepts 'model' as alias."""
|
||||
provider = GenerativeAiProvider(
|
||||
api_key="test-key",
|
||||
model="gemini-embedding-001",
|
||||
)
|
||||
assert provider.model_name == "gemini-embedding-001"
|
||||
|
||||
def test_google_vertex_provider_accepts_model_key(self):
|
||||
"""Test Google Vertex AI provider accepts 'model' as alias."""
|
||||
provider = VertexAIProvider(
|
||||
api_key="test-key",
|
||||
model="text-embedding-004",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-004"
|
||||
|
||||
def test_azure_provider_accepts_model_key(self):
|
||||
"""Test Azure provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = AzureProvider(
|
||||
api_key="test-key",
|
||||
deployment_id="test-deployment",
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-ada-002"
|
||||
|
||||
def test_jina_provider_accepts_model_key(self):
|
||||
"""Test Jina provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = JinaProvider(
|
||||
api_key="test-key",
|
||||
model="jina-embeddings-v3",
|
||||
)
|
||||
assert provider.model_name == "jina-embeddings-v3"
|
||||
|
||||
def test_ollama_provider_accepts_model_key(self):
|
||||
"""Test Ollama provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = OllamaProvider(
|
||||
model="nomic-embed-text",
|
||||
)
|
||||
assert provider.model_name == "nomic-embed-text"
|
||||
|
||||
def test_text2vec_provider_accepts_model_key(self):
|
||||
"""Test Text2Vec provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = Text2VecProvider(
|
||||
model="shibing624/text2vec-base-multilingual",
|
||||
)
|
||||
assert provider.model_name == "shibing624/text2vec-base-multilingual"
|
||||
|
||||
def test_sentence_transformer_provider_accepts_model_key(self):
|
||||
"""Test SentenceTransformer provider accepts 'model' as alias."""
|
||||
provider = SentenceTransformerProvider(
|
||||
model="all-mpnet-base-v2",
|
||||
)
|
||||
assert provider.model_name == "all-mpnet-base-v2"
|
||||
|
||||
def test_instructor_provider_accepts_model_key(self):
|
||||
"""Test Instructor provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = InstructorProvider(
|
||||
model="hkunlp/instructor-xl",
|
||||
)
|
||||
assert provider.model_name == "hkunlp/instructor-xl"
|
||||
|
||||
def test_openclip_provider_accepts_model_key(self):
|
||||
"""Test OpenCLIP provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = OpenCLIPProvider(
|
||||
model="ViT-B-16",
|
||||
)
|
||||
assert provider.model_name == "ViT-B-16"
|
||||
|
||||
|
||||
class TestTaskTypeConfiguration:
|
||||
"""Test that task_type configuration works correctly."""
|
||||
|
||||
def test_google_provider_accepts_lowercase_task_type(self):
|
||||
"""Test Google provider accepts lowercase task_type."""
|
||||
provider = GenerativeAiProvider(
|
||||
api_key="test-key",
|
||||
task_type="retrieval_document",
|
||||
)
|
||||
assert provider.task_type == "retrieval_document"
|
||||
|
||||
def test_google_provider_accepts_uppercase_task_type(self):
|
||||
"""Test Google provider accepts uppercase task_type."""
|
||||
provider = GenerativeAiProvider(
|
||||
api_key="test-key",
|
||||
task_type="RETRIEVAL_QUERY",
|
||||
)
|
||||
assert provider.task_type == "RETRIEVAL_QUERY"
|
||||
|
||||
def test_google_provider_default_task_type(self):
|
||||
"""Test Google provider has correct default task_type."""
|
||||
provider = GenerativeAiProvider(
|
||||
api_key="test-key",
|
||||
)
|
||||
assert provider.task_type == "RETRIEVAL_DOCUMENT"
|
||||
|
||||
|
||||
class TestFactoryBackwardCompatibility:
|
||||
"""Test factory function with backward compatible configurations."""
|
||||
|
||||
def test_factory_with_google_alias(self):
|
||||
"""Test factory resolves 'google' to google-generativeai provider."""
|
||||
config = {
|
||||
"provider": "google",
|
||||
"config": {
|
||||
"api_key": "test-key",
|
||||
"model": "gemini-embedding-001",
|
||||
},
|
||||
}
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
with patch("crewai.rag.embeddings.factory.import_and_validate_definition") as mock_import:
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
|
||||
)
|
||||
|
||||
def test_factory_with_model_key_openai(self):
|
||||
"""Test factory passes 'model' config to OpenAI provider."""
|
||||
config = {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "test-key",
|
||||
"model": "text-embedding-3-small",
|
||||
},
|
||||
}
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
with patch("crewai.rag.embeddings.factory.import_and_validate_definition") as mock_import:
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["model"] == "text-embedding-3-small"
|
||||
|
||||
|
||||
class TestDocumentationCodeSnippets:
|
||||
"""Test code snippets from documentation work correctly."""
|
||||
|
||||
def test_memory_openai_config(self):
|
||||
"""Test OpenAI config from memory.mdx documentation."""
|
||||
provider = OpenAIProvider(
|
||||
model_name="text-embedding-3-small",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-3-small"
|
||||
|
||||
def test_memory_openai_config_with_options(self):
|
||||
"""Test OpenAI config with all options from memory.mdx."""
|
||||
provider = OpenAIProvider(
|
||||
api_key="your-openai-api-key",
|
||||
model_name="text-embedding-3-large",
|
||||
dimensions=1536,
|
||||
organization_id="your-org-id",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-3-large"
|
||||
assert provider.dimensions == 1536
|
||||
|
||||
def test_memory_azure_config(self):
|
||||
"""Test Azure config from memory.mdx documentation."""
|
||||
provider = AzureProvider(
|
||||
api_key="your-azure-key",
|
||||
api_base="https://your-resource.openai.azure.com/",
|
||||
api_type="azure",
|
||||
api_version="2023-05-15",
|
||||
model_name="text-embedding-3-small",
|
||||
            deployment_id="your-deployment-name",
        )
        assert provider.model_name == "text-embedding-3-small"
        assert provider.api_type == "azure"

    def test_memory_google_generativeai_config(self):
        """Test Google Generative AI config from memory.mdx documentation."""
        provider = GenerativeAiProvider(
            api_key="your-google-api-key",
            model_name="gemini-embedding-001",
        )
        assert provider.model_name == "gemini-embedding-001"

    def test_memory_cohere_config(self):
        """Test Cohere config from memory.mdx documentation."""
        provider = CohereProvider(
            api_key="your-cohere-api-key",
            model_name="embed-english-v3.0",
        )
        assert provider.model_name == "embed-english-v3.0"

    def test_knowledge_agent_embedder_config(self):
        """Test agent embedder config from knowledge.mdx documentation."""
        provider = GenerativeAiProvider(
            model_name="gemini-embedding-001",
            api_key="your-google-key",
        )
        assert provider.model_name == "gemini-embedding-001"

    def test_ragtool_openai_config(self):
        """Test RagTool OpenAI config from ragtool.mdx documentation."""
        provider = OpenAIProvider(
            model_name="text-embedding-3-small",
        )
        assert provider.model_name == "text-embedding-3-small"

    def test_ragtool_cohere_config(self):
        """Test RagTool Cohere config from ragtool.mdx documentation."""
        provider = CohereProvider(
            api_key="your-api-key",
            model_name="embed-english-v3.0",
        )
        assert provider.model_name == "embed-english-v3.0"

    def test_ragtool_ollama_config(self):
        """Test RagTool Ollama config from ragtool.mdx documentation."""
        provider = OllamaProvider(
            model_name="llama2",
            url="http://localhost:11434/api/embeddings",
        )
        assert provider.model_name == "llama2"

    def test_ragtool_azure_config(self):
        """Test RagTool Azure config from ragtool.mdx documentation."""
        provider = AzureProvider(
            deployment_id="your-deployment-id",
            api_key="your-api-key",
            api_base="https://your-resource.openai.azure.com",
            api_version="2024-02-01",
            model_name="text-embedding-ada-002",
            api_type="azure",
        )
        assert provider.model_name == "text-embedding-ada-002"
        assert provider.deployment_id == "your-deployment-id"

    def test_ragtool_google_generativeai_config(self):
        """Test RagTool Google Generative AI config from ragtool.mdx."""
        provider = GenerativeAiProvider(
            api_key="your-api-key",
            model_name="gemini-embedding-001",
            task_type="RETRIEVAL_DOCUMENT",
        )
        assert provider.model_name == "gemini-embedding-001"
        assert provider.task_type == "RETRIEVAL_DOCUMENT"

    def test_ragtool_jina_config(self):
        """Test RagTool Jina config from ragtool.mdx documentation."""
        provider = JinaProvider(
            api_key="your-api-key",
            model_name="jina-embeddings-v3",
        )
        assert provider.model_name == "jina-embeddings-v3"

    def test_ragtool_sentence_transformer_config(self):
        """Test RagTool SentenceTransformer config from ragtool.mdx."""
        provider = SentenceTransformerProvider(
            model_name="all-mpnet-base-v2",
            device="cuda",
            normalize_embeddings=True,
        )
        assert provider.model_name == "all-mpnet-base-v2"
        assert provider.device == "cuda"
        assert provider.normalize_embeddings is True
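

# The legacy tests below pass the older "model" key; they assume each provider
# model aliases it to "model_name" (the alias itself lives in the provider
# definitions, which are outside this diff).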
class TestLegacyConfigurationFormats:
    """Test legacy configuration formats that should still work."""

    def test_legacy_google_with_model_key(self):
        """Test legacy Google config using 'model' instead of 'model_name'."""
        provider = GenerativeAiProvider(
            api_key="test-key",
            model="text-embedding-005",
            task_type="retrieval_document",
        )
        assert provider.model_name == "text-embedding-005"
        assert provider.task_type == "retrieval_document"

    def test_legacy_openai_with_model_key(self):
        """Test legacy OpenAI config using 'model' instead of 'model_name'."""
        provider = OpenAIProvider(
            api_key="test-key",
            model="text-embedding-ada-002",
        )
        assert provider.model_name == "text-embedding-ada-002"

    def test_legacy_cohere_with_model_key(self):
        """Test legacy Cohere config using 'model' instead of 'model_name'."""
        provider = CohereProvider(
            api_key="test-key",
            model="embed-multilingual-v3.0",
        )
        assert provider.model_name == "embed-multilingual-v3.0"

    def test_legacy_azure_with_model_key(self):
        """Test legacy Azure config using 'model' instead of 'model_name'."""
        provider = AzureProvider(
            api_key="test-key",
            deployment_id="test-deployment",
            model="text-embedding-3-large",
        )
        assert provider.model_name == "text-embedding-3-large"
82  lib/crewai/tests/rag/test_rag_storage_path.py  Normal file
@@ -0,0 +1,82 @@
"""Tests for RAGStorage custom path functionality."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
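

# The tests below patch create_client and build_embedder where RAGStorage
# imports them, so no real vector store or embedding provider is touched;
# each test inspects the config object handed to create_client instead.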
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_custom_path(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses custom path when provided."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
custom_path = "/custom/memory/path"
|
||||
embedder_config = {"provider": "openai", "config": {"model": "text-embedding-3-small"}}
|
||||
|
||||
RAGStorage(
|
||||
type="short_term",
|
||||
crew=None,
|
||||
path=custom_path,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
config_arg = mock_create_client.call_args[0][0]
|
||||
assert config_arg.settings.persist_directory == custom_path
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_default_path_when_none(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses default path when no custom path is provided."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
embedder_config = {"provider": "openai", "config": {"model": "text-embedding-3-small"}}
|
||||
|
||||
storage = RAGStorage(
|
||||
type="short_term",
|
||||
crew=None,
|
||||
path=None,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
assert storage.path is None
|
||||
|
||||
|
||||
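

# The next test assumes extra embedder config keys such as batch_size are
# forwarded onto the client config object rather than consumed silently.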
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_custom_path_with_batch_size(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses custom path with batch_size in config."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
custom_path = "/custom/batch/path"
|
||||
embedder_config = {
|
||||
"provider": "openai",
|
||||
"config": {"model": "text-embedding-3-small", "batch_size": 100},
|
||||
}
|
||||
|
||||
RAGStorage(
|
||||
type="long_term",
|
||||
crew=None,
|
||||
path=custom_path,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
config_arg = mock_create_client.call_args[0][0]
|
||||
assert config_arg.settings.persist_directory == custom_path
|
||||
assert config_arg.batch_size == 100
|
||||
@@ -723,11 +723,11 @@ def test_structured_flow_event_emission():
    assert isinstance(received_events[3], MethodExecutionStartedEvent)
    assert received_events[3].method_name == "send_welcome_message"
    assert received_events[3].params == {}
    assert received_events[3].state.sent is False
    assert received_events[3].state["sent"] is False

    assert isinstance(received_events[4], MethodExecutionFinishedEvent)
    assert received_events[4].method_name == "send_welcome_message"
    assert received_events[4].state.sent is True
    assert received_events[4].state["sent"] is True
    assert received_events[4].result == "Welcome, Anakin!"

    assert isinstance(received_events[5], FlowFinishedEvent)

@@ -415,4 +415,256 @@ def test_router_paths_not_in_and_conditions():

    assert "step_1" in targets
    assert "step_3_or" in targets
    assert "step_2_and" not in targets
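

# The router tests below build the static flow graph via build_flow_structure
# without executing the flow: edges come from decorator wiring plus whatever
# return values can be inferred from each router method.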
def test_chained_routers_no_self_loops():
    """Test that chained routers don't create self-referencing edges.

    This tests the bug where routers with string triggers (like 'auth', 'exp')
    would incorrectly create edges to themselves when another router outputs
    those strings.
    """

    class ChainedRouterFlow(Flow):
        """Flow with multiple chained routers using string outputs."""

        @start()
        def entrance(self):
            return "started"

        @router(entrance)
        def session_in_cache(self):
            return "exp"

        @router("exp")
        def check_exp(self):
            return "auth"

        @router("auth")
        def call_ai_auth(self):
            return "action"

        @listen("action")
        def forward_to_action(self):
            return "done"

        @listen("authenticate")
        def forward_to_authenticate(self):
            return "need_auth"

    flow = ChainedRouterFlow()
    structure = build_flow_structure(flow)

    # Check that no self-loops exist
    for edge in structure["edges"]:
        assert edge["source"] != edge["target"], (
            f"Self-loop detected: {edge['source']} -> {edge['target']}"
        )

    # Verify correct connections
    router_edges = [edge for edge in structure["edges"] if edge["is_router_path"]]

    # session_in_cache -> check_exp (via 'exp')
    exp_edges = [
        edge
        for edge in router_edges
        if edge["router_path_label"] == "exp" and edge["source"] == "session_in_cache"
    ]
    assert len(exp_edges) == 1
    assert exp_edges[0]["target"] == "check_exp"

    # check_exp -> call_ai_auth (via 'auth')
    auth_edges = [
        edge
        for edge in router_edges
        if edge["router_path_label"] == "auth" and edge["source"] == "check_exp"
    ]
    assert len(auth_edges) == 1
    assert auth_edges[0]["target"] == "call_ai_auth"

    # call_ai_auth -> forward_to_action (via 'action')
    action_edges = [
        edge
        for edge in router_edges
        if edge["router_path_label"] == "action" and edge["source"] == "call_ai_auth"
    ]
    assert len(action_edges) == 1
    assert action_edges[0]["target"] == "forward_to_action"


def test_routers_with_shared_output_strings():
    """Test that routers with shared output strings don't create incorrect edges.

    This tests a scenario where multiple routers can output the same string,
    ensuring the visualization only creates edges for the router that actually
    outputs the string, not all routers.
    """

    class SharedOutputRouterFlow(Flow):
        """Flow where multiple routers can output 'auth'."""

        @start()
        def start(self):
            return "started"

        @router(start)
        def router_a(self):
            # This router can output 'auth' or 'skip'
            return "auth"

        @router("auth")
        def router_b(self):
            # This router listens to 'auth' but outputs 'done'
            return "done"

        @listen("done")
        def finalize(self):
            return "complete"

        @listen("skip")
        def handle_skip(self):
            return "skipped"

    flow = SharedOutputRouterFlow()
    structure = build_flow_structure(flow)

    # Check no self-loops
    for edge in structure["edges"]:
        assert edge["source"] != edge["target"], (
            f"Self-loop detected: {edge['source']} -> {edge['target']}"
        )

    # router_a should connect to router_b via 'auth'
    router_edges = [edge for edge in structure["edges"] if edge["is_router_path"]]
    auth_from_a = [
        edge
        for edge in router_edges
        if edge["source"] == "router_a" and edge["router_path_label"] == "auth"
    ]
    assert len(auth_from_a) == 1
    assert auth_from_a[0]["target"] == "router_b"

    # router_b should connect to finalize via 'done'
    done_from_b = [
        edge
        for edge in router_edges
        if edge["source"] == "router_b" and edge["router_path_label"] == "done"
    ]
    assert len(done_from_b) == 1
    assert done_from_b[0]["target"] == "finalize"


def test_warning_for_router_without_paths(caplog):
    """Test that a warning is logged when a router has no determinable paths."""
    import logging

    class RouterWithoutPathsFlow(Flow):
        """Flow with a router that returns a dynamic value."""

        @start()
        def begin(self):
            return "started"

        @router(begin)
        def dynamic_router(self):
            # Returns a variable that can't be statically analyzed
            import random
            return random.choice(["path_a", "path_b"])

        @listen("path_a")
        def handle_a(self):
            return "a"

        @listen("path_b")
        def handle_b(self):
            return "b"

    flow = RouterWithoutPathsFlow()

    with caplog.at_level(logging.WARNING):
        build_flow_structure(flow)

    # Check that warning was logged for the router
    assert any(
        "Could not determine return paths for router 'dynamic_router'" in record.message
        for record in caplog.records
    )

    # Check that error was logged for orphaned triggers
    assert any(
        "Found listeners waiting for triggers" in record.message
        for record in caplog.records
    )


def test_warning_for_orphaned_listeners(caplog):
    """Test that an error is logged when listeners wait for triggers no router outputs."""
    import logging
    from typing import Literal

    class OrphanedListenerFlow(Flow):
        """Flow where a listener waits for a trigger that no router outputs."""

        @start()
        def begin(self):
            return "started"

        @router(begin)
        def my_router(self) -> Literal["option_a", "option_b"]:
            return "option_a"

        @listen("option_a")
        def handle_a(self):
            return "a"

        @listen("option_c")  # This trigger is never output by any router
        def handle_orphan(self):
            return "orphan"

    flow = OrphanedListenerFlow()

    with caplog.at_level(logging.ERROR):
        build_flow_structure(flow)

    # Check that error was logged for orphaned trigger
    assert any(
        "Found listeners waiting for triggers" in record.message
        and "option_c" in record.message
        for record in caplog.records
    )


def test_no_warning_for_properly_typed_router(caplog):
    """Test that no warning is logged when router has proper type annotations."""
    import logging
    from typing import Literal

    class ProperlyTypedRouterFlow(Flow):
        """Flow with properly typed router."""

        @start()
        def begin(self):
            return "started"

        @router(begin)
        def typed_router(self) -> Literal["path_a", "path_b"]:
            return "path_a"

        @listen("path_a")
        def handle_a(self):
            return "a"

        @listen("path_b")
        def handle_b(self):
            return "b"

    flow = ProperlyTypedRouterFlow()

    with caplog.at_level(logging.WARNING):
        build_flow_structure(flow)

    # No warnings should be logged
    warning_messages = [r.message for r in caplog.records if r.levelno >= logging.WARNING]
    assert not any("Could not determine return paths" in msg for msg in warning_messages)
    assert not any("Found listeners waiting for triggers" in msg for msg in warning_messages)
@@ -243,7 +243,11 @@ def test_validate_call_params_not_supported():

    # Patch supports_response_schema to simulate an unsupported model.
    with patch("crewai.llm.supports_response_schema", return_value=False):
        llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
        llm = LLM(
            model="gemini/gemini-1.5-pro",
            response_format=DummyResponse,
            is_litellm=True,
        )
        with pytest.raises(ValueError) as excinfo:
            llm._validate_call_params()
        assert "does not support response_format" in str(excinfo.value)

@@ -259,6 +263,7 @@ def test_validate_call_params_no_response_format():
@pytest.mark.parametrize(
    "model",
    [
        "gemini/gemini-3-pro-preview",
        "gemini/gemini-2.0-flash-thinking-exp-01-21",
        "gemini/gemini-2.0-flash-001",
        "gemini/gemini-2.0-flash-lite-001",

@@ -701,13 +706,16 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):

    assert formatted == original_messages


def test_native_provider_raises_error_when_supported_but_fails():
    """Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
    with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
        with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
            # Mock that provider exists but throws an error when instantiated
            mock_provider = MagicMock()
            mock_provider.side_effect = ValueError("Native provider initialization failed")
            mock_provider.side_effect = ValueError(
                "Native provider initialization failed"
            )
            mock_get_native.return_value = mock_provider

            with pytest.raises(ImportError) as excinfo:

@@ -750,16 +758,16 @@ def test_prefixed_models_with_valid_constants_use_native_sdk():


def test_prefixed_models_with_invalid_constants_use_litellm():
    """Test that models with native provider prefixes use LiteLLM when model is NOT in constants."""
    """Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns."""
    # Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM
    llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False)
    assert llm.is_litellm is True
    assert llm.model == "openai/gemini-2.5-flash"

    # Test openai/ prefix with unknown future model → LiteLLM
    llm2 = LLM(model="openai/gpt-future-6", is_litellm=False)
    # Test openai/ prefix with model that doesn't match patterns (e.g. no gpt- prefix) → LiteLLM
    llm2 = LLM(model="openai/custom-finetune-model", is_litellm=False)
    assert llm2.is_litellm is True
    assert llm2.model == "openai/gpt-future-6"
    assert llm2.model == "openai/custom-finetune-model"

    # Test anthropic/ prefix with non-Anthropic model → LiteLLM
    llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False)

@@ -767,6 +775,21 @@ def test_prefixed_models_with_invalid_constants_use_litellm():
    assert llm3.model == "anthropic/gpt-4o"
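

# The tests below cover pattern-based prefix routing: a model name that matches
# the provider's known pattern (e.g. gpt-, claude-) uses the native SDK even
# when absent from the constants list, while non-native prefixes always fall
# back to LiteLLM. API keys are stubbed with patch.dict so nothing hits the
# network.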
def test_prefixed_models_with_valid_patterns_use_native_sdk():
    """Test that models matching provider patterns use native SDK even if not in constants."""
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
        llm = LLM(model="openai/gpt-future-6", is_litellm=False)
        assert llm.is_litellm is False
        assert llm.provider == "openai"
        assert llm.model == "gpt-future-6"

    with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
        llm2 = LLM(model="anthropic/claude-future-5", is_litellm=False)
        assert llm2.is_litellm is False
        assert llm2.provider == "anthropic"
        assert llm2.model == "claude-future-5"


def test_prefixed_models_with_non_native_providers_use_litellm():
    """Test that models with non-native provider prefixes always use LiteLLM."""
    # Test groq/ prefix (not a native provider) → LiteLLM

@@ -820,19 +843,36 @@ def test_validate_model_in_constants():
    """Test the _validate_model_in_constants method."""
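    # After this change, "future" model names (gpt-future-6, claude-future-5,
    # gemini-future) validate as True, so the check appears to accept
    # pattern-based matches rather than only an exact constants lookup.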
    # OpenAI models
    assert LLM._validate_model_in_constants("gpt-4o", "openai") is True
    assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False
    assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True
    assert LLM._validate_model_in_constants("o1-latest", "openai") is True
    assert LLM._validate_model_in_constants("unknown-model", "openai") is False

    # Anthropic models
    assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True
    assert LLM._validate_model_in_constants("claude-future-5", "claude") is False
    assert LLM._validate_model_in_constants("claude-future-5", "claude") is True
    assert (
        LLM._validate_model_in_constants("claude-3-5-sonnet-latest", "claude") is True
    )
    assert LLM._validate_model_in_constants("unknown-model", "claude") is False

    # Gemini models
    assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True
    assert LLM._validate_model_in_constants("gemini-future", "gemini") is False
    assert LLM._validate_model_in_constants("gemini-future", "gemini") is True
    assert LLM._validate_model_in_constants("gemma-3-latest", "gemini") is True
    assert LLM._validate_model_in_constants("unknown-model", "gemini") is False

    # Azure models
    assert LLM._validate_model_in_constants("gpt-4o", "azure") is True
    assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True

    # Bedrock models
    assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True
    assert (
        LLM._validate_model_in_constants(
            "anthropic.claude-opus-4-1-20250805-v1:0", "bedrock"
        )
        is True
    )
    assert (
        LLM._validate_model_in_constants("anthropic.claude-future-v1:0", "bedrock")
        is True
    )

@@ -272,6 +272,99 @@ def another_simple_tool():
    return "Hi!"


class TestAsyncDecoratorSupport:
    """Tests for async method support in @agent, @task decorators."""
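
    # Calling a decorated async method synchronously here returns the built
    # Agent/Task rather than a coroutine, so the decorators presumably run the
    # coroutine to completion and memoize the result; the tests below pin down
    # exactly that contract.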
    def test_async_agent_memoization(self):
        """Async agent methods should be properly memoized."""

        class AsyncAgentCrew:
            call_count = 0

            @agent
            async def async_agent(self):
                AsyncAgentCrew.call_count += 1
                return Agent(
                    role="Async Agent", goal="Async Goal", backstory="Async Backstory"
                )

        crew = AsyncAgentCrew()
        first_call = crew.async_agent()
        second_call = crew.async_agent()

        assert first_call is second_call, "Async agent memoization failed"
        assert AsyncAgentCrew.call_count == 1, "Async agent called more than once"

    def test_async_task_memoization(self):
        """Async task methods should be properly memoized."""

        class AsyncTaskCrew:
            call_count = 0

            @task
            async def async_task(self):
                AsyncTaskCrew.call_count += 1
                return Task(
                    description="Async Description", expected_output="Async Output"
                )

        crew = AsyncTaskCrew()
        first_call = crew.async_task()
        second_call = crew.async_task()

        assert first_call is second_call, "Async task memoization failed"
        assert AsyncTaskCrew.call_count == 1, "Async task called more than once"

    def test_async_task_name_inference(self):
        """Async task should have name inferred from method name."""

        class AsyncTaskNameCrew:
            @task
            async def my_async_task(self):
                return Task(
                    description="Async Description", expected_output="Async Output"
                )

        crew = AsyncTaskNameCrew()
        task_instance = crew.my_async_task()

        assert task_instance.name == "my_async_task", (
            "Async task name not inferred correctly"
        )

    def test_async_agent_returns_agent_not_coroutine(self):
        """Async agent decorator should return Agent, not coroutine."""

        class AsyncAgentTypeCrew:
            @agent
            async def typed_async_agent(self):
                return Agent(
                    role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory"
                )

        crew = AsyncAgentTypeCrew()
        result = crew.typed_async_agent()

        assert isinstance(result, Agent), (
            f"Expected Agent, got {type(result).__name__}"
        )

    def test_async_task_returns_task_not_coroutine(self):
        """Async task decorator should return Task, not coroutine."""

        class AsyncTaskTypeCrew:
            @task
            async def typed_async_task(self):
                return Task(
                    description="Typed Description", expected_output="Typed Output"
                )

        crew = AsyncTaskTypeCrew()
        result = crew.typed_async_task()

        assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}"


def test_internal_crew_with_mcp():
    from crewai_tools.adapters.tool_collection import ToolCollection

717  lib/crewai/tests/test_streaming.py  Normal file
@@ -0,0 +1,717 @@
"""Tests for streaming output functionality in crews and flows."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator, Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.llm_events import LLMStreamChunkEvent, ToolCall, FunctionCall
|
||||
from crewai.flow.flow import Flow, start
|
||||
from crewai.types.streaming import (
|
||||
CrewStreamingOutput,
|
||||
FlowStreamingOutput,
|
||||
StreamChunk,
|
||||
StreamChunkType,
|
||||
ToolCallChunk,
|
||||
)
|
||||
|
||||
|
||||
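

# CrewStreamingOutput/FlowStreamingOutput wrap a sync or async iterator of
# StreamChunk objects; most tests below feed them hand-built generators, so
# the streaming plumbing is exercised without a real LLM stream.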
@pytest.fixture
def researcher() -> Agent:
    """Create a researcher agent for testing."""
    return Agent(
        role="Researcher",
        goal="Research and analyze topics thoroughly",
        backstory="You are an expert researcher with deep analytical skills.",
        allow_delegation=False,
    )


@pytest.fixture
def simple_task(researcher: Agent) -> Task:
    """Create a simple task for testing."""
    return Task(
        description="Write a brief analysis of AI trends",
        expected_output="A concise analysis of current AI trends",
        agent=researcher,
    )


@pytest.fixture
def simple_crew(researcher: Agent, simple_task: Task) -> Crew:
    """Create a simple crew with one agent and one task."""
    return Crew(
        agents=[researcher],
        tasks=[simple_task],
        verbose=False,
    )


@pytest.fixture
def streaming_crew(researcher: Agent, simple_task: Task) -> Crew:
    """Create a streaming crew with one agent and one task."""
    return Crew(
        agents=[researcher],
        tasks=[simple_task],
        verbose=False,
        stream=True,
    )


class TestStreamChunk:
    """Tests for StreamChunk model."""

    def test_stream_chunk_creation(self) -> None:
        """Test creating a basic stream chunk."""
        chunk = StreamChunk(
            content="Hello, world!",
            chunk_type=StreamChunkType.TEXT,
            task_index=0,
            task_name="Test Task",
            task_id="task-123",
            agent_role="Researcher",
            agent_id="agent-456",
        )
        assert chunk.content == "Hello, world!"
        assert chunk.chunk_type == StreamChunkType.TEXT
        assert chunk.task_index == 0
        assert chunk.task_name == "Test Task"
        assert str(chunk) == "Hello, world!"

    def test_stream_chunk_with_tool_call(self) -> None:
        """Test creating a stream chunk with tool call information."""
        tool_call = ToolCallChunk(
            tool_id="call-123",
            tool_name="search",
            arguments='{"query": "AI trends"}',
            index=0,
        )
        chunk = StreamChunk(
            content="",
            chunk_type=StreamChunkType.TOOL_CALL,
            tool_call=tool_call,
        )
        assert chunk.chunk_type == StreamChunkType.TOOL_CALL
        assert chunk.tool_call is not None
        assert chunk.tool_call.tool_name == "search"


class TestCrewStreamingOutput:
    """Tests for CrewStreamingOutput functionality."""

    def test_result_before_iteration_raises_error(self) -> None:
        """Test that accessing result before iteration raises error."""

        def empty_gen() -> Generator[StreamChunk, None, None]:
            yield StreamChunk(content="test")

        streaming = CrewStreamingOutput(sync_iterator=empty_gen())
        with pytest.raises(RuntimeError, match="Streaming has not completed yet"):
            _ = streaming.result

    def test_is_completed_property(self) -> None:
        """Test the is_completed property."""

        def simple_gen() -> Generator[StreamChunk, None, None]:
            yield StreamChunk(content="test")

        streaming = CrewStreamingOutput(sync_iterator=simple_gen())
        assert streaming.is_completed is False

        list(streaming)
        assert streaming.is_completed is True

    def test_get_full_text(self) -> None:
        """Test getting full text from chunks."""

        def gen() -> Generator[StreamChunk, None, None]:
            yield StreamChunk(content="Hello ")
            yield StreamChunk(content="World!")
            yield StreamChunk(content="", chunk_type=StreamChunkType.TOOL_CALL)

        streaming = CrewStreamingOutput(sync_iterator=gen())
        list(streaming)
        assert streaming.get_full_text() == "Hello World!"

    def test_chunks_property(self) -> None:
        """Test accessing collected chunks."""

        def gen() -> Generator[StreamChunk, None, None]:
            yield StreamChunk(content="chunk1")
            yield StreamChunk(content="chunk2")

        streaming = CrewStreamingOutput(sync_iterator=gen())
        list(streaming)
        assert len(streaming.chunks) == 2
        assert streaming.chunks[0].content == "chunk1"


class TestFlowStreamingOutput:
    """Tests for FlowStreamingOutput functionality."""

    def test_result_before_iteration_raises_error(self) -> None:
        """Test that accessing result before iteration raises error."""

        def empty_gen() -> Generator[StreamChunk, None, None]:
            yield StreamChunk(content="test")

        streaming = FlowStreamingOutput(sync_iterator=empty_gen())
        with pytest.raises(RuntimeError, match="Streaming has not completed yet"):
            _ = streaming.result

    def test_is_completed_property(self) -> None:
        """Test the is_completed property."""

        def simple_gen() -> Generator[StreamChunk, None, None]:
            yield StreamChunk(content="test")

        streaming = FlowStreamingOutput(sync_iterator=simple_gen())
        assert streaming.is_completed is False

        list(streaming)
        assert streaming.is_completed is True


class TestCrewKickoffStreaming:
    """Tests for Crew(stream=True).kickoff() method."""

    def test_kickoff_streaming_returns_streaming_output(self, streaming_crew: Crew) -> None:
        """Test that kickoff with stream=True returns CrewStreamingOutput."""
        with patch.object(Crew, "kickoff") as mock_kickoff:
            mock_output = MagicMock()
            mock_output.raw = "Test output"

            def side_effect(*args: Any, **kwargs: Any) -> Any:
                return mock_output

            mock_kickoff.side_effect = side_effect

            streaming = streaming_crew.kickoff()
            assert isinstance(streaming, CrewStreamingOutput)

    def test_kickoff_streaming_captures_chunks(self, researcher: Agent, simple_task: Task) -> None:
        """Test that streaming captures LLM chunks."""
        crew = Crew(
            agents=[researcher],
            tasks=[simple_task],
            verbose=False,
            stream=True,
        )

        mock_output = MagicMock()
        mock_output.raw = "Test output"
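
        # Crew.kickoff is patched with a re-entrant stub: the first call is the
        # streaming wrapper itself (delegated to the real kickoff), and the
        # nested second call emits synthetic LLMStreamChunkEvents on the event
        # bus instead of invoking an LLM.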
        original_kickoff = Crew.kickoff
        call_count = [0]

        def mock_kickoff_fn(self: Any, inputs: Any = None) -> Any:
            call_count[0] += 1
            if call_count[0] == 1:
                return original_kickoff(self, inputs)
            else:
                crewai_event_bus.emit(
                    crew,
                    LLMStreamChunkEvent(
                        type="llm_stream_chunk",
                        chunk="Hello ",
                    ),
                )
                crewai_event_bus.emit(
                    crew,
                    LLMStreamChunkEvent(
                        type="llm_stream_chunk",
                        chunk="World!",
                    ),
                )
                return mock_output

        with patch.object(Crew, "kickoff", mock_kickoff_fn):
            streaming = crew.kickoff()
            assert isinstance(streaming, CrewStreamingOutput)
            chunks = list(streaming)

        assert len(chunks) >= 2
        contents = [c.content for c in chunks]
        assert "Hello " in contents
        assert "World!" in contents

    def test_kickoff_streaming_result_available_after_iteration(
        self, researcher: Agent, simple_task: Task
    ) -> None:
        """Test that result is available after iterating all chunks."""
        mock_output = MagicMock()
        mock_output.raw = "Final result"

        def gen() -> Generator[StreamChunk, None, None]:
            yield StreamChunk(content="test chunk")

        streaming = CrewStreamingOutput(sync_iterator=gen())

        # Iterate all chunks
        _ = list(streaming)

        # Simulate what _finalize_streaming does
        streaming._set_result(mock_output)

        result = streaming.result
        assert result.raw == "Final result"

    def test_kickoff_streaming_handles_tool_calls(self, researcher: Agent, simple_task: Task) -> None:
        """Test that streaming handles tool call chunks correctly."""
        crew = Crew(
            agents=[researcher],
            tasks=[simple_task],
            verbose=False,
            stream=True,
        )

        mock_output = MagicMock()
        mock_output.raw = "Test output"

        original_kickoff = Crew.kickoff
        call_count = [0]

        def mock_kickoff_fn(self: Any, inputs: Any = None) -> Any:
            call_count[0] += 1
            if call_count[0] == 1:
                return original_kickoff(self, inputs)
            else:
                crewai_event_bus.emit(
                    crew,
                    LLMStreamChunkEvent(
                        type="llm_stream_chunk",
                        chunk="",
                        tool_call=ToolCall(
                            id="call-123",
                            function=FunctionCall(
                                name="search",
                                arguments='{"query": "test"}',
                            ),
                            type="function",
                            index=0,
                        ),
                    ),
                )
                return mock_output

        with patch.object(Crew, "kickoff", mock_kickoff_fn):
            streaming = crew.kickoff()
            assert isinstance(streaming, CrewStreamingOutput)
            chunks = list(streaming)

        tool_chunks = [c for c in chunks if c.chunk_type == StreamChunkType.TOOL_CALL]
        assert len(tool_chunks) >= 1
        assert tool_chunks[0].tool_call is not None
        assert tool_chunks[0].tool_call.tool_name == "search"


class TestCrewKickoffStreamingAsync:
    """Tests for Crew(stream=True).kickoff_async() method."""

    @pytest.mark.asyncio
    async def test_kickoff_streaming_async_returns_streaming_output(
        self, researcher: Agent, simple_task: Task
    ) -> None:
        """Test that kickoff_async with stream=True returns CrewStreamingOutput."""
        crew = Crew(
            agents=[researcher],
            tasks=[simple_task],
            verbose=False,
            stream=True,
        )

        mock_output = MagicMock()
        mock_output.raw = "Test output"

        original_kickoff = Crew.kickoff
        call_count = [0]

        def mock_kickoff_fn(self: Any, inputs: Any = None) -> Any:
            call_count[0] += 1
            if call_count[0] == 1:
                return original_kickoff(self, inputs)
            else:
                return mock_output

        with patch.object(Crew, "kickoff", mock_kickoff_fn):
            streaming = await crew.kickoff_async()

        assert isinstance(streaming, CrewStreamingOutput)

    @pytest.mark.asyncio
    async def test_kickoff_streaming_async_captures_chunks(
        self, researcher: Agent, simple_task: Task
    ) -> None:
        """Test that async streaming captures LLM chunks."""
        crew = Crew(
            agents=[researcher],
            tasks=[simple_task],
            verbose=False,
            stream=True,
        )

        mock_output = MagicMock()
        mock_output.raw = "Test output"

        def mock_kickoff_fn(self: Any, inputs: Any = None) -> Any:
            crewai_event_bus.emit(
                crew,
                LLMStreamChunkEvent(
                    type="llm_stream_chunk",
                    chunk="Async ",
                ),
            )
            crewai_event_bus.emit(
                crew,
                LLMStreamChunkEvent(
                    type="llm_stream_chunk",
                    chunk="Stream!",
                ),
            )
            return mock_output

        with patch.object(Crew, "kickoff", mock_kickoff_fn):
            streaming = await crew.kickoff_async()
            assert isinstance(streaming, CrewStreamingOutput)
            chunks: list[StreamChunk] = []
            async for chunk in streaming:
                chunks.append(chunk)

        assert len(chunks) >= 2
        contents = [c.content for c in chunks]
        assert "Async " in contents
        assert "Stream!" in contents

    @pytest.mark.asyncio
    async def test_kickoff_streaming_async_result_available_after_iteration(
        self, researcher: Agent, simple_task: Task
    ) -> None:
        """Test that result is available after async iteration."""
        mock_output = MagicMock()
        mock_output.raw = "Async result"

        async def async_gen() -> AsyncIterator[StreamChunk]:
            yield StreamChunk(content="test chunk")

        streaming = CrewStreamingOutput(async_iterator=async_gen())

        # Iterate all chunks
        async for _ in streaming:
            pass

        # Simulate what _finalize_streaming does
        streaming._set_result(mock_output)

        result = streaming.result
        assert result.raw == "Async result"
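

# Note that Flow streaming is toggled by assigning flow.stream = True after
# construction; unlike Crew, these tests do not pass stream via a constructor.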
class TestFlowKickoffStreaming:
    """Tests for Flow(stream=True).kickoff() method."""

    def test_kickoff_streaming_returns_streaming_output(self) -> None:
        """Test that flow kickoff with stream=True returns FlowStreamingOutput."""

        class SimpleFlow(Flow[dict[str, Any]]):
            @start()
            def generate(self) -> str:
                return "result"

        flow = SimpleFlow()
        flow.stream = True
        streaming = flow.kickoff()
        assert isinstance(streaming, FlowStreamingOutput)

    def test_flow_kickoff_streaming_captures_chunks(self) -> None:
        """Test that flow streaming captures LLM chunks from crew execution."""

        class TestFlow(Flow[dict[str, Any]]):
            @start()
            def run_crew(self) -> str:
                return "done"

        flow = TestFlow()
        flow.stream = True

        original_kickoff = Flow.kickoff
        call_count = [0]

        def mock_kickoff_fn(self: Any, inputs: Any = None) -> Any:
            call_count[0] += 1
            if call_count[0] == 1:
                return original_kickoff(self, inputs)
            else:
                crewai_event_bus.emit(
                    flow,
                    LLMStreamChunkEvent(
                        type="llm_stream_chunk",
                        chunk="Flow ",
                    ),
                )
                crewai_event_bus.emit(
                    flow,
                    LLMStreamChunkEvent(
                        type="llm_stream_chunk",
                        chunk="output!",
                    ),
                )
                return "done"

        with patch.object(Flow, "kickoff", mock_kickoff_fn):
            streaming = flow.kickoff()
            assert isinstance(streaming, FlowStreamingOutput)
            chunks = list(streaming)

        assert len(chunks) >= 2
        contents = [c.content for c in chunks]
        assert "Flow " in contents
        assert "output!" in contents

    def test_flow_kickoff_streaming_result_available(self) -> None:
        """Test that flow result is available after iteration."""

        class TestFlow(Flow[dict[str, Any]]):
            @start()
            def generate(self) -> str:
                return "flow result"

        flow = TestFlow()
        flow.stream = True

        original_kickoff = Flow.kickoff
        call_count = [0]

        def mock_kickoff_fn(self: Any, inputs: Any = None) -> Any:
            call_count[0] += 1
            if call_count[0] == 1:
                return original_kickoff(self, inputs)
            else:
                return "flow result"

        with patch.object(Flow, "kickoff", mock_kickoff_fn):
            streaming = flow.kickoff()
            assert isinstance(streaming, FlowStreamingOutput)
            _ = list(streaming)

        result = streaming.result
        assert result == "flow result"


class TestFlowKickoffStreamingAsync:
    """Tests for Flow(stream=True).kickoff_async() method."""

    @pytest.mark.asyncio
    async def test_kickoff_streaming_async_returns_streaming_output(self) -> None:
        """Test that flow kickoff_async with stream=True returns FlowStreamingOutput."""

        class SimpleFlow(Flow[dict[str, Any]]):
            @start()
            async def generate(self) -> str:
                return "async result"

        flow = SimpleFlow()
        flow.stream = True
        streaming = await flow.kickoff_async()
        assert isinstance(streaming, FlowStreamingOutput)

    @pytest.mark.asyncio
    async def test_flow_kickoff_streaming_async_captures_chunks(self) -> None:
        """Test that async flow streaming captures LLM chunks."""

        class TestFlow(Flow[dict[str, Any]]):
            @start()
            async def run_crew(self) -> str:
                return "done"

        flow = TestFlow()
        flow.stream = True

        original_kickoff = Flow.kickoff_async
        call_count = [0]

        async def mock_kickoff_fn(self: Any, inputs: Any = None) -> Any:
            call_count[0] += 1
            if call_count[0] == 1:
                return await original_kickoff(self, inputs)
            else:
                await asyncio.sleep(0.01)
                crewai_event_bus.emit(
                    flow,
                    LLMStreamChunkEvent(
                        type="llm_stream_chunk",
                        chunk="Async flow ",
                    ),
                )
                await asyncio.sleep(0.01)
                crewai_event_bus.emit(
                    flow,
                    LLMStreamChunkEvent(
                        type="llm_stream_chunk",
                        chunk="stream!",
                    ),
                )
                await asyncio.sleep(0.01)
                return "done"

        with patch.object(Flow, "kickoff_async", mock_kickoff_fn):
            streaming = await flow.kickoff_async()
            assert isinstance(streaming, FlowStreamingOutput)
            chunks: list[StreamChunk] = []
            async for chunk in streaming:
                chunks.append(chunk)

        assert len(chunks) >= 2
        contents = [c.content for c in chunks]
        assert "Async flow " in contents
        assert "stream!" in contents

    @pytest.mark.asyncio
    async def test_flow_kickoff_streaming_async_result_available(self) -> None:
        """Test that async flow result is available after iteration."""

        class TestFlow(Flow[dict[str, Any]]):
            @start()
            async def generate(self) -> str:
                return "async flow result"

        flow = TestFlow()
        flow.stream = True

        original_kickoff = Flow.kickoff_async
        call_count = [0]

        async def mock_kickoff_fn(self: Any, inputs: Any = None) -> Any:
            call_count[0] += 1
            if call_count[0] == 1:
                return await original_kickoff(self, inputs)
            else:
                return "async flow result"

        with patch.object(Flow, "kickoff_async", mock_kickoff_fn):
            streaming = await flow.kickoff_async()
            assert isinstance(streaming, FlowStreamingOutput)
            async for _ in streaming:
                pass

        result = streaming.result
        assert result == "async flow result"


class TestStreamingEdgeCases:
    """Tests for edge cases in streaming functionality."""

    def test_streaming_handles_exceptions(self, researcher: Agent, simple_task: Task) -> None:
        """Test that streaming properly propagates exceptions."""
        crew = Crew(
            agents=[researcher],
            tasks=[simple_task],
            verbose=False,
            stream=True,
        )

        original_kickoff = Crew.kickoff
        call_count = [0]

        def mock_kickoff_fn(self: Any, inputs: Any = None) -> Any:
            call_count[0] += 1
            if call_count[0] == 1:
                return original_kickoff(self, inputs)
            else:
                raise ValueError("Test error")

        with patch.object(Crew, "kickoff", mock_kickoff_fn):
            streaming = crew.kickoff()
            with pytest.raises(ValueError, match="Test error"):
                list(streaming)

    def test_streaming_with_empty_content_chunks(self) -> None:
        """Test streaming when LLM chunks have empty content."""
        mock_output = MagicMock()
        mock_output.raw = "No streaming"

        def gen() -> Generator[StreamChunk, None, None]:
            yield StreamChunk(content="")

        streaming = CrewStreamingOutput(sync_iterator=gen())
        chunks = list(streaming)

        assert streaming.is_completed
        assert len(chunks) == 1
        assert chunks[0].content == ""

        # Simulate what _finalize_streaming does
        streaming._set_result(mock_output)

        result = streaming.result
        assert result.raw == "No streaming"

    def test_streaming_with_multiple_tasks(self, researcher: Agent) -> None:
        """Test streaming with multiple tasks tracks task context."""
        task1 = Task(
            description="First task",
            expected_output="First output",
            agent=researcher,
        )
        task2 = Task(
            description="Second task",
            expected_output="Second output",
            agent=researcher,
        )
        crew = Crew(
            agents=[researcher],
            tasks=[task1, task2],
            verbose=False,
            stream=True,
        )

        mock_output = MagicMock()
        mock_output.raw = "Multi-task output"

        original_kickoff = Crew.kickoff
        call_count = [0]

        def mock_kickoff_fn(self: Any, inputs: Any = None) -> Any:
            call_count[0] += 1
            if call_count[0] == 1:
                return original_kickoff(self, inputs)
            else:
                crewai_event_bus.emit(
                    crew,
                    LLMStreamChunkEvent(
                        type="llm_stream_chunk",
                        chunk="Task 1",
                        task_name="First task",
                    ),
                )
                return mock_output

        with patch.object(Crew, "kickoff", mock_kickoff_fn):
            streaming = crew.kickoff()
            assert isinstance(streaming, CrewStreamingOutput)
            chunks = list(streaming)

        assert len(chunks) >= 1
        assert streaming.is_completed


class TestStreamingImports:
    """Tests for correct imports of streaming types."""

    def test_streaming_types_importable_from_types_module(self) -> None:
        """Test that streaming types can be imported from crewai.types.streaming."""
        from crewai.types.streaming import (
            CrewStreamingOutput,
            FlowStreamingOutput,
            StreamChunk,
            StreamChunkType,
            ToolCallChunk,
        )

        assert CrewStreamingOutput is not None
        assert FlowStreamingOutput is not None
        assert StreamChunk is not None
        assert StreamChunkType is not None
        assert ToolCallChunk is not None
290  lib/crewai/tests/test_streaming_integration.py  Normal file
@@ -0,0 +1,290 @@
"""Integration tests for streaming with real LLM interactions using cassettes."""
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.flow.flow import Flow, start
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
|
||||
|
||||
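

# These tests replay recorded HTTP traffic through pytest-vcr cassettes
# (@pytest.mark.vcr), so the real streaming code path runs without live API
# calls; authorization headers are filtered out of the recordings.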
@pytest.fixture
def researcher() -> Agent:
    """Create a researcher agent for testing."""
    return Agent(
        role="Research Analyst",
        goal="Gather comprehensive information on topics",
        backstory="You are an experienced researcher with excellent analytical skills.",
        allow_delegation=False,
    )


@pytest.fixture
def simple_task(researcher: Agent) -> Task:
    """Create a simple research task."""
    return Task(
        description="Research the latest developments in {topic}",
        expected_output="A brief summary of recent developments",
        agent=researcher,
    )


class TestStreamingCrewIntegration:
    """Integration tests for crew streaming that match documentation examples."""

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_basic_crew_streaming_from_docs(
        self, researcher: Agent, simple_task: Task
    ) -> None:
        """Test basic streaming example from documentation."""
        crew = Crew(
            agents=[researcher],
            tasks=[simple_task],
            stream=True,
            verbose=False,
        )

        streaming = crew.kickoff(inputs={"topic": "artificial intelligence"})

        assert isinstance(streaming, CrewStreamingOutput)

        chunks = []
        for chunk in streaming:
            chunks.append(chunk.content)

        assert len(chunks) > 0

        result = streaming.result
        assert result.raw is not None
        assert len(result.raw) > 0

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_streaming_with_chunk_context_from_docs(
        self, researcher: Agent, simple_task: Task
    ) -> None:
        """Test streaming with chunk context example from documentation."""
        crew = Crew(
            agents=[researcher],
            tasks=[simple_task],
            stream=True,
            verbose=False,
        )

        streaming = crew.kickoff(inputs={"topic": "AI"})

        chunk_contexts = []
        for chunk in streaming:
            chunk_contexts.append(
                {
                    "task_name": chunk.task_name,
                    "task_index": chunk.task_index,
                    "agent_role": chunk.agent_role,
                    "content": chunk.content,
                    "type": chunk.chunk_type,
                }
            )

        assert len(chunk_contexts) > 0
        assert all("agent_role" in ctx for ctx in chunk_contexts)

        result = streaming.result
        assert result is not None

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_streaming_properties_from_docs(
        self, researcher: Agent, simple_task: Task
    ) -> None:
        """Test streaming properties example from documentation."""
        crew = Crew(
            agents=[researcher],
            tasks=[simple_task],
            stream=True,
            verbose=False,
        )

        streaming = crew.kickoff(inputs={"topic": "AI"})

        for _ in streaming:
            pass

        assert streaming.is_completed is True
        full_text = streaming.get_full_text()
        assert len(full_text) > 0
        assert len(streaming.chunks) > 0

        result = streaming.result
        assert result.raw is not None

    @pytest.mark.vcr(filter_headers=["authorization"])
    @pytest.mark.asyncio
    async def test_async_streaming_from_docs(
        self, researcher: Agent, simple_task: Task
    ) -> None:
        """Test async streaming example from documentation."""
        crew = Crew(
            agents=[researcher],
            tasks=[simple_task],
            stream=True,
            verbose=False,
        )

        streaming = await crew.kickoff_async(inputs={"topic": "AI"})

        assert isinstance(streaming, CrewStreamingOutput)

        chunks = []
        async for chunk in streaming:
            chunks.append(chunk.content)

        assert len(chunks) > 0

        result = streaming.result
        assert result.raw is not None

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_kickoff_for_each_streaming_from_docs(
        self, researcher: Agent, simple_task: Task
    ) -> None:
        """Test kickoff_for_each streaming example from documentation."""
        crew = Crew(
            agents=[researcher],
            tasks=[simple_task],
            stream=True,
            verbose=False,
        )

        inputs_list = [{"topic": "AI in healthcare"}, {"topic": "AI in finance"}]

        streaming_outputs = crew.kickoff_for_each(inputs=inputs_list)

        assert len(streaming_outputs) == 2
        assert all(isinstance(s, CrewStreamingOutput) for s in streaming_outputs)

        results = []
        for streaming in streaming_outputs:
            for _ in streaming:
                pass

            result = streaming.result
            results.append(result)

        assert len(results) == 2
        assert all(r.raw is not None for r in results)


class TestStreamingFlowIntegration:
    """Integration tests for flow streaming that match documentation examples."""

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_basic_flow_streaming_from_docs(self) -> None:
        """Test basic flow streaming example from documentation."""

        class ResearchFlow(Flow):
            stream = True

            @start()
            def research_topic(self) -> str:
                researcher = Agent(
                    role="Research Analyst",
                    goal="Research topics thoroughly",
                    backstory="Expert researcher with analytical skills",
                    allow_delegation=False,
                )

                task = Task(
                    description="Research AI trends and provide insights",
                    expected_output="Detailed research findings",
                    agent=researcher,
                )

                crew = Crew(
                    agents=[researcher],
                    tasks=[task],
                    stream=True,
                    verbose=False,
                )

                streaming = crew.kickoff()
                for _ in streaming:
                    pass
                return streaming.result.raw

        flow = ResearchFlow()

        streaming = flow.kickoff()

        assert isinstance(streaming, FlowStreamingOutput)

        chunks = []
        for chunk in streaming:
            chunks.append(chunk.content)

        assert len(chunks) > 0

        result = streaming.result
        assert result is not None

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_flow_streaming_properties_from_docs(self) -> None:
        """Test flow streaming properties example from documentation."""

        class SimpleFlow(Flow):
            stream = True

            @start()
            def execute(self) -> str:
                return "Flow result"

        flow = SimpleFlow()
        streaming = flow.kickoff()

        for _ in streaming:
            pass

        assert streaming.is_completed is True
        streaming.get_full_text()
        assert len(streaming.chunks) >= 0

        result = streaming.result
        assert result is not None

    @pytest.mark.vcr(filter_headers=["authorization"])
    @pytest.mark.asyncio
    async def test_async_flow_streaming_from_docs(self) -> None:
        """Test async flow streaming example from documentation."""

        class AsyncResearchFlow(Flow):
            stream = True

            @start()
            def research(self) -> str:
                researcher = Agent(
                    role="Researcher",
                    goal="Research topics",
                    backstory="Expert researcher",
                    allow_delegation=False,
                )

                task = Task(
                    description="Research AI",
                    expected_output="Research findings",
                    agent=researcher,
                )

                crew = Crew(agents=[researcher], tasks=[task], stream=True, verbose=False)
                streaming = crew.kickoff()
                for _ in streaming:
                    pass
                return streaming.result.raw

        flow = AsyncResearchFlow()

        streaming = await flow.kickoff_async()

        assert isinstance(streaming, FlowStreamingOutput)

        chunks = []
        async for chunk in streaming:
            chunks.append(chunk.content)

        result = streaming.result
        assert result is not None

@@ -26,6 +26,7 @@ from crewai.events.types.flow_events import (
    FlowFinishedEvent,
    FlowStartedEvent,
    MethodExecutionFailedEvent,
    MethodExecutionFinishedEvent,
    MethodExecutionStartedEvent,
)
from crewai.events.types.llm_events import (

@@ -47,7 +48,7 @@ from crewai.flow.flow import Flow, listen, start
from crewai.llm import LLM
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from pydantic import Field
from pydantic import BaseModel, Field
import pytest

from ..utils import wait_for_event_handlers

@@ -703,6 +704,156 @@ def test_flow_emits_method_execution_failed_event():
    assert received_events[0].error == error
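

# The three tests below verify that method execution events snapshot flow state
# as a plain dict. Note the state captured on a *started* event reflects the
# state before that method body runs, which is why "process" still sees the
# values written by "begin".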
def test_flow_method_execution_started_includes_unstructured_state():
    """Test that MethodExecutionStartedEvent includes unstructured (dict) state."""
    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(MethodExecutionStartedEvent)
    def handle_method_started(source, event):
        received_events.append(event)
        if event.method_name == "process":
            event_received.set()

    class TestFlow(Flow[dict]):
        @start()
        def begin(self):
            self.state["counter"] = 1
            self.state["message"] = "test"
            return "started"

        @listen("begin")
        def process(self):
            self.state["counter"] = 2
            return "processed"

    flow = TestFlow()
    flow.kickoff()

    assert event_received.wait(timeout=5), (
        "Timeout waiting for method execution started event"
    )

    # Find the events for each method
    begin_event = next(e for e in received_events if e.method_name == "begin")
    process_event = next(e for e in received_events if e.method_name == "process")

    # Verify state is included and is a dict
    assert begin_event.state is not None
    assert isinstance(begin_event.state, dict)
    assert "id" in begin_event.state  # Auto-generated ID

    # Verify state from begin method is captured in process event
    assert process_event.state is not None
    assert isinstance(process_event.state, dict)
    assert process_event.state["counter"] == 1
    assert process_event.state["message"] == "test"


def test_flow_method_execution_started_includes_structured_state():
    """Test that MethodExecutionStartedEvent includes structured (BaseModel) state and serializes it properly."""
    received_events = []
    event_received = threading.Event()

    class FlowState(BaseModel):
        counter: int = 0
        message: str = ""
        items: list[str] = []

    @crewai_event_bus.on(MethodExecutionStartedEvent)
    def handle_method_started(source, event):
        received_events.append(event)
        if event.method_name == "process":
            event_received.set()

    class TestFlow(Flow[FlowState]):
        @start()
        def begin(self):
            self.state.counter = 1
            self.state.message = "initial"
            self.state.items = ["a", "b"]
            return "started"

        @listen("begin")
        def process(self):
            self.state.counter += 1
            return "processed"

    flow = TestFlow()
    flow.kickoff()

    assert event_received.wait(timeout=5), (
        "Timeout waiting for method execution started event"
    )

    begin_event = next(e for e in received_events if e.method_name == "begin")
    process_event = next(e for e in received_events if e.method_name == "process")

    assert begin_event.state is not None
    assert isinstance(begin_event.state, dict)
    assert begin_event.state["counter"] == 0  # Initial state
    assert begin_event.state["message"] == ""
    assert begin_event.state["items"] == []

    assert process_event.state is not None
    assert isinstance(process_event.state, dict)
    assert process_event.state["counter"] == 1
    assert process_event.state["message"] == "initial"
    assert process_event.state["items"] == ["a", "b"]


def test_flow_method_execution_finished_includes_serialized_state():
    """Test that MethodExecutionFinishedEvent includes properly serialized state."""
    received_events = []
    event_received = threading.Event()

    class FlowState(BaseModel):
        result: str = ""
        completed: bool = False

    @crewai_event_bus.on(MethodExecutionFinishedEvent)
    def handle_method_finished(source, event):
        received_events.append(event)
        if event.method_name == "process":
            event_received.set()

    class TestFlow(Flow[FlowState]):
        @start()
        def begin(self):
            self.state.result = "begin done"
            return "started"

        @listen("begin")
        def process(self):
            self.state.result = "process done"
            self.state.completed = True
            return "final_result"

    flow = TestFlow()
    final_output = flow.kickoff()

    assert event_received.wait(timeout=5), (
        "Timeout waiting for method execution finished event"
    )

    begin_finished = next(e for e in received_events if e.method_name == "begin")
    process_finished = next(e for e in received_events if e.method_name == "process")

    assert begin_finished.state is not None
    assert isinstance(begin_finished.state, dict)
    assert begin_finished.state["result"] == "begin done"
    assert begin_finished.state["completed"] is False
    assert begin_finished.result == "started"

    # Verify process finished event has final state and result
    assert process_finished.state is not None
    assert isinstance(process_finished.state, dict)
    assert process_finished.state["result"] == "process done"
    assert process_finished.state["completed"] is True
    assert process_finished.result == "final_result"
    assert final_output == "final_result"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_emits_call_started_event():
|
||||
received_events = []
|
||||
|
||||