Compare commits

..

2 Commits

Author SHA1 Message Date
github-actions[bot]
aaa478159d chore: update tool specifications 2026-02-20 10:03:20 +00:00
Greyson LaLonde
ddcfffe3ab refactor: simplify platform integration token resolution
Remove unused platform_context context manager and env var fallback
from context module. Token resolves from context var or env var via
default_factory on the tool field. Replace custom __init__ with
model_validator and use sanitize_tool_name.
2026-02-20 05:01:52 -05:00
29 changed files with 191 additions and 526 deletions

View File

@@ -6,8 +6,10 @@ from typing import Any
from crewai.tools import BaseTool
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
from pydantic import Field, create_model
from crewai.utilities.string_utils import sanitize_tool_name
from pydantic import Field, create_model, model_validator
import requests
from typing_extensions import Self
from crewai_tools.tools.crewai_platform_tools.misc import (
get_platform_api_base_url,
@@ -20,34 +22,27 @@ class CrewAIPlatformActionTool(BaseTool):
action_schema: dict[str, Any] = Field(
default_factory=dict, description="The schema of the action"
)
integration_token: str | None = Field(
default_factory=get_platform_integration_token,
)
def __init__(
self,
description: str,
action_name: str,
action_schema: dict[str, Any],
):
parameters = action_schema.get("function", {}).get("parameters", {})
@model_validator(mode="after")
def _build_args_schema(self) -> Self:
parameters = self.action_schema.get("function", {}).get("parameters", {})
if parameters and parameters.get("properties"):
try:
if "title" not in parameters:
parameters = {**parameters, "title": f"{action_name}Schema"}
parameters = {**parameters, "title": f"{self.action_name}Schema"}
if "type" not in parameters:
parameters = {**parameters, "type": "object"}
args_schema = create_model_from_schema(parameters)
self.args_schema = create_model_from_schema(parameters)
except Exception:
args_schema = create_model(f"{action_name}Schema")
self.args_schema = create_model(f"{self.action_name}Schema")
else:
args_schema = create_model(f"{action_name}Schema")
super().__init__(
name=action_name.lower().replace(" ", "_"),
description=description,
args_schema=args_schema,
)
self.action_name = action_name
self.action_schema = action_schema
self.args_schema = create_model(f"{self.action_name}Schema")
if not self.name:
self.name = sanitize_tool_name(self.action_name)
return self
def _run(self, **kwargs: Any) -> str:
try:
@@ -58,9 +53,8 @@ class CrewAIPlatformActionTool(BaseTool):
api_url = (
f"{get_platform_api_base_url()}/actions/{self.action_name}/execute"
)
token = get_platform_integration_token()
headers = {
"Authorization": f"Bearer {token}",
"Authorization": f"Bearer {self.integration_token}",
"Content-Type": "application/json",
}
payload = {

View File

@@ -6,6 +6,7 @@ from types import TracebackType
from typing import Any
from crewai.tools import BaseTool
from crewai.utilities.string_utils import sanitize_tool_name
import requests
from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import (
@@ -30,6 +31,7 @@ class CrewaiPlatformToolBuilder:
self._apps = apps
self._actions_schema: dict[str, dict[str, Any]] = {}
self._tools: list[BaseTool] | None = None
self._integration_token = get_platform_integration_token()
def tools(self) -> list[BaseTool]:
"""Fetch actions and return built tools."""
@@ -41,7 +43,7 @@ class CrewaiPlatformToolBuilder:
def _fetch_actions(self) -> None:
"""Fetch action schemas from the platform API."""
actions_url = f"{get_platform_api_base_url()}/actions"
headers = {"Authorization": f"Bearer {get_platform_integration_token()}"}
headers = {"Authorization": f"Bearer {self._integration_token}"}
try:
response = requests.get(
@@ -88,9 +90,11 @@ class CrewaiPlatformToolBuilder:
description = function_details.get("description", f"Execute {action_name}")
tool = CrewAIPlatformActionTool(
name=sanitize_tool_name(action_name),
description=description,
action_name=action_name,
action_schema=action_schema,
integration_token=self._integration_token,
)
tools.append(tool)

View File

@@ -1,5 +1,7 @@
import os
from crewai.context import get_platform_integration_token as _get_context_token
def get_platform_api_base_url() -> str:
"""Get the platform API base URL from environment or use default."""
@@ -7,11 +9,5 @@ def get_platform_api_base_url() -> str:
return f"{base_url}/crewai_plus/api/v1/integrations"
def get_platform_integration_token() -> str:
"""Get the platform API base URL from environment or use default."""
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN") or ""
if not token:
raise ValueError(
"No platform integration token found, please set the CREWAI_PLATFORM_INTEGRATION_TOKEN environment variable"
)
return token # TODO: Use context manager to get token
def get_platform_integration_token() -> str | None:
return _get_context_token() or os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")

View File

@@ -27,9 +27,10 @@ class TestCrewAIPlatformActionToolVerify:
def create_test_tool(self):
return CrewAIPlatformActionTool(
name="test_action",
description="Test action tool",
action_name="test_action",
action_schema=self.action_schema
action_schema=self.action_schema,
)
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"}, clear=True)

View File

@@ -107,12 +107,10 @@ class TestCrewaiPlatformToolBuilder(unittest.TestCase):
)
def test_fetch_actions_no_token(self):
builder = CrewaiPlatformToolBuilder(apps=["github"])
with patch.dict("os.environ", {}, clear=True):
with self.assertRaises(ValueError) as context:
builder._fetch_actions()
assert "No platform integration token found" in str(context.exception)
builder = CrewaiPlatformToolBuilder(apps=["github"])
assert builder._integration_token is None
assert builder.tools() == []
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch(

View File

@@ -110,6 +110,5 @@ class TestCrewaiPlatformTools(unittest.TestCase):
def test_crewai_platform_tools_no_token(self):
with patch.dict("os.environ", {}, clear=True):
with self.assertRaises(ValueError) as context:
CrewaiPlatformTools(apps=["github"])
assert "No platform integration token found" in str(context.exception)
tools = CrewaiPlatformTools(apps=["github"])
assert tools == []

View File

@@ -20117,18 +20117,6 @@
"humanized_name": "Web Automation Tool",
"init_params_schema": {
"$defs": {
"AvailableModel": {
"enum": [
"gpt-4o",
"gpt-4o-mini",
"claude-3-5-sonnet-latest",
"claude-3-7-sonnet-latest",
"computer-use-preview",
"gemini-2.0-flash"
],
"title": "AvailableModel",
"type": "string"
},
"EnvVar": {
"properties": {
"default": {
@@ -20206,17 +20194,6 @@
"default": null,
"title": "Model Api Key"
},
"model_name": {
"anyOf": [
{
"$ref": "#/$defs/AvailableModel"
},
{
"type": "null"
}
],
"default": "claude-3-7-sonnet-latest"
},
"project_id": {
"anyOf": [
{

View File

@@ -38,7 +38,6 @@ dependencies = [
"json5~=0.10.0",
"portalocker~=2.7.0",
"pydantic-settings~=2.10.1",
"httpx~=0.28.1",
"mcp~=1.26.0",
"uv~=0.9.13",
"aiosqlite~=0.21.0",

View File

@@ -6,10 +6,8 @@ and memory management.
from __future__ import annotations
import asyncio
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
import inspect
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -738,9 +736,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
] = []
for call_id, func_name, func_args in parsed_calls:
original_tool = original_tools_by_name.get(func_name)
execution_plan.append(
(call_id, func_name, func_args, original_tool)
)
execution_plan.append((call_id, func_name, func_args, original_tool))
self._append_assistant_tool_calls_message(
[
@@ -750,9 +746,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
max_workers = min(8, len(execution_plan))
ordered_results: list[dict[str, Any] | None] = [None] * len(
execution_plan
)
ordered_results: list[dict[str, Any] | None] = [None] * len(execution_plan)
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = {
pool.submit(
@@ -809,7 +803,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return tool_finish
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message = {
reasoning_message: LLMMessage = {
"role": "user",
"content": reasoning_prompt,
}
@@ -914,9 +908,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
elif (
should_execute
and original_tool
and (max_count := getattr(original_tool, "max_usage_count", None))
is not None
and getattr(original_tool, "current_usage_count", 0) >= max_count
and getattr(original_tool, "max_usage_count", None) is not None
and getattr(original_tool, "current_usage_count", 0)
>= original_tool.max_usage_count
):
max_usage_reached = True
@@ -995,17 +989,13 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
and hasattr(original_tool, "cache_function")
and callable(original_tool.cache_function)
):
should_cache = original_tool.cache_function(
args_dict, raw_result
)
should_cache = original_tool.cache_function(args_dict, raw_result)
if should_cache:
self.tools_handler.cache.add(
tool=func_name, input=input_str, output=raw_result
)
result = (
str(raw_result) if not isinstance(raw_result, str) else raw_result
)
result = str(raw_result) if not isinstance(raw_result, str) else raw_result
except Exception as e:
result = f"Error executing tool: {e}"
if self.task:
@@ -1500,9 +1490,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer: Current agent response.
"""
if self.step_callback:
cb_result = self.step_callback(formatted_answer)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
self.step_callback(formatted_answer)
def _append_message(
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"

View File

@@ -2,8 +2,8 @@ import time
from typing import TYPE_CHECKING, Any, TypeVar, cast
import webbrowser
import httpx
from pydantic import BaseModel, Field
import requests
from rich.console import Console
from crewai.cli.authentication.utils import validate_jwt_token
@@ -98,7 +98,7 @@ class AuthenticationCommand:
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
"audience": self.oauth2_provider.get_audience(),
}
response = httpx.post(
response = requests.post(
url=self.oauth2_provider.get_authorize_url(),
data=device_code_payload,
timeout=20,
@@ -130,7 +130,7 @@ class AuthenticationCommand:
attempts = 0
while True and attempts < 10:
response = httpx.post(
response = requests.post(
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
)
token_data = response.json()
@@ -149,7 +149,7 @@ class AuthenticationCommand:
return
if token_data["error"] not in ("authorization_pending", "slow_down"):
raise httpx.HTTPError(
raise requests.HTTPError(
token_data.get("error_description") or token_data.get("error")
)

View File

@@ -1,6 +1,5 @@
import json
import httpx
import requests
from requests.exceptions import JSONDecodeError
from rich.console import Console
from crewai.cli.authentication.token import get_auth_token
@@ -31,16 +30,16 @@ class PlusAPIMixin:
console.print("Run 'crewai login' to sign up/login.", style="bold green")
raise SystemExit from None
def _validate_response(self, response: httpx.Response) -> None:
def _validate_response(self, response: requests.Response) -> None:
"""
Handle and display error messages from API responses.
Args:
response (httpx.Response): The response from the Plus API
response (requests.Response): The response from the Plus API
"""
try:
json_response = response.json()
except (json.JSONDecodeError, ValueError):
except (JSONDecodeError, ValueError):
console.print(
"Failed to parse response from Enterprise API failed. Details:",
style="bold red",
@@ -63,7 +62,7 @@ class PlusAPIMixin:
)
raise SystemExit
if not response.is_success:
if not response.ok:
console.print(
"Request to Enterprise API failed. Details:", style="bold red"
)

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, cast
import httpx
import requests
from requests.exceptions import JSONDecodeError, RequestException
from rich.console import Console
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
@@ -47,12 +47,12 @@ class EnterpriseConfigureCommand(BaseCommand):
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
"X-Crewai-Version": get_crewai_version(),
}
response = httpx.get(oauth_endpoint, timeout=30, headers=headers)
response = requests.get(oauth_endpoint, timeout=30, headers=headers)
response.raise_for_status()
try:
oauth_config = response.json()
except json.JSONDecodeError as e:
except JSONDecodeError as e:
raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e
self._validate_oauth_config(oauth_config)
@@ -62,7 +62,7 @@ class EnterpriseConfigureCommand(BaseCommand):
)
return cast(dict[str, Any], oauth_config)
except httpx.HTTPError as e:
except RequestException as e:
raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e
except Exception as e:
raise ValueError(f"Error fetching OAuth2 configuration: {e!s}") from e

View File

@@ -1,4 +1,4 @@
from httpx import HTTPStatusError
from requests import HTTPError
from rich.console import Console
from rich.table import Table
@@ -10,11 +10,11 @@ console = Console()
class OrganizationCommand(BaseCommand, PlusAPIMixin):
def __init__(self) -> None:
def __init__(self):
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
def list(self) -> None:
def list(self):
try:
response = self.plus_api_client.get_organizations()
response.raise_for_status()
@@ -33,7 +33,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
table.add_row(org["name"], org["uuid"])
console.print(table)
except HTTPStatusError as e:
except HTTPError as e:
if e.response.status_code == 401:
console.print(
"You are not logged in to any organization. Use 'crewai login' to login.",
@@ -50,7 +50,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
)
raise SystemExit(1) from e
def switch(self, org_id: str) -> None:
def switch(self, org_id):
try:
response = self.plus_api_client.get_organizations()
response.raise_for_status()
@@ -72,7 +72,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
f"Successfully switched to {org['name']} ({org['uuid']})",
style="bold green",
)
except HTTPStatusError as e:
except HTTPError as e:
if e.response.status_code == 401:
console.print(
"You are not logged in to any organization. Use 'crewai login' to login.",
@@ -87,7 +87,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
console.print(f"Failed to switch organization: {e!s}", style="bold red")
raise SystemExit(1) from e
def current(self) -> None:
def current(self):
settings = Settings()
if settings.org_uuid:
console.print(

View File

@@ -3,6 +3,7 @@ from typing import Any
from urllib.parse import urljoin
import httpx
import requests
from crewai.cli.config import Settings
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
@@ -42,16 +43,16 @@ class PlusAPI:
def _make_request(
self, method: str, endpoint: str, **kwargs: Any
) -> httpx.Response:
) -> requests.Response:
url = urljoin(self.base_url, endpoint)
verify = kwargs.pop("verify", True)
with httpx.Client(trust_env=False, verify=verify) as client:
return client.request(method, url, headers=self.headers, **kwargs)
session = requests.Session()
session.trust_env = False
return session.request(method, url, headers=self.headers, **kwargs)
def login_to_tool_repository(self) -> httpx.Response:
def login_to_tool_repository(self) -> requests.Response:
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
def get_tool(self, handle: str) -> httpx.Response:
def get_tool(self, handle: str) -> requests.Response:
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
async def get_agent(self, handle: str) -> httpx.Response:
@@ -67,7 +68,7 @@ class PlusAPI:
description: str | None,
encoded_file: str,
available_exports: list[dict[str, Any]] | None = None,
) -> httpx.Response:
) -> requests.Response:
params = {
"handle": handle,
"public": is_public,
@@ -78,52 +79,54 @@ class PlusAPI:
}
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
def deploy_by_name(self, project_name: str) -> httpx.Response:
def deploy_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
)
def deploy_by_uuid(self, uuid: str) -> httpx.Response:
def deploy_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
def crew_status_by_name(self, project_name: str) -> httpx.Response:
def crew_status_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
)
def crew_status_by_uuid(self, uuid: str) -> httpx.Response:
def crew_status_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
def crew_by_name(
self, project_name: str, log_type: str = "deployment"
) -> httpx.Response:
) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
)
def crew_by_uuid(self, uuid: str, log_type: str = "deployment") -> httpx.Response:
def crew_by_uuid(
self, uuid: str, log_type: str = "deployment"
) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
)
def delete_crew_by_name(self, project_name: str) -> httpx.Response:
def delete_crew_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
)
def delete_crew_by_uuid(self, uuid: str) -> httpx.Response:
def delete_crew_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
def list_crews(self) -> httpx.Response:
def list_crews(self) -> requests.Response:
return self._make_request("GET", self.CREWS_RESOURCE)
def create_crew(self, payload: dict[str, Any]) -> httpx.Response:
def create_crew(self, payload: dict[str, Any]) -> requests.Response:
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
def get_organizations(self) -> httpx.Response:
def get_organizations(self) -> requests.Response:
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
def initialize_trace_batch(self, payload: dict[str, Any]) -> httpx.Response:
def initialize_trace_batch(self, payload: dict[str, Any]) -> requests.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches",
@@ -133,7 +136,7 @@ class PlusAPI:
def initialize_ephemeral_trace_batch(
self, payload: dict[str, Any]
) -> httpx.Response:
) -> requests.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
@@ -142,7 +145,7 @@ class PlusAPI:
def send_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
) -> requests.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
@@ -152,7 +155,7 @@ class PlusAPI:
def send_ephemeral_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
) -> requests.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
@@ -162,7 +165,7 @@ class PlusAPI:
def finalize_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
) -> requests.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
@@ -172,7 +175,7 @@ class PlusAPI:
def finalize_ephemeral_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
) -> requests.Response:
return self._make_request(
"PATCH",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
@@ -182,7 +185,7 @@ class PlusAPI:
def mark_trace_batch_as_failed(
self, trace_batch_id: str, error_message: str
) -> httpx.Response:
) -> requests.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
@@ -190,11 +193,13 @@ class PlusAPI:
timeout=30,
)
def get_triggers(self) -> httpx.Response:
def get_triggers(self) -> requests.Response:
"""Get all available triggers from integrations."""
return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps")
def get_trigger_payload(self, app_slug: str, trigger_slug: str) -> httpx.Response:
def get_trigger_payload(
self, app_slug: str, trigger_slug: str
) -> requests.Response:
"""Get sample payload for a specific trigger."""
return self._make_request(
"GET", f"{self.INTEGRATIONS_RESOURCE}/{app_slug}/{trigger_slug}/payload"

View File

@@ -8,7 +8,7 @@ from typing import Any
import certifi
import click
import httpx
import requests
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
@@ -165,20 +165,20 @@ def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
try:
with httpx.stream("GET", JSON_URL, timeout=60, verify=ssl_config) as response:
response.raise_for_status()
data = download_data(response)
with open(cache_file, "w") as f:
json.dump(data, f)
return data
except httpx.HTTPError as e:
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)
response.raise_for_status()
data = download_data(response)
with open(cache_file, "w") as f:
json.dump(data, f)
return data
except requests.RequestException as e:
click.secho(f"Error fetching provider data: {e}", fg="red")
except json.JSONDecodeError:
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return None
def download_data(response: httpx.Response) -> dict[str, Any]:
def download_data(response: requests.Response) -> dict[str, Any]:
"""Downloads data from a given HTTP response and returns the JSON content.
Args:
@@ -194,7 +194,7 @@ def download_data(response: httpx.Response) -> dict[str, Any]:
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as bar:
for chunk in response.iter_bytes(block_size):
for chunk in response.iter_content(block_size):
if chunk:
data_chunks.append(chunk)
bar.update(len(chunk))

View File

@@ -1,8 +1,4 @@
from collections.abc import Generator
from contextlib import contextmanager
import contextvars
import os
from typing import Any
_platform_integration_token: contextvars.ContextVar[str | None] = (
@@ -10,39 +6,9 @@ _platform_integration_token: contextvars.ContextVar[str | None] = (
)
def set_platform_integration_token(integration_token: str) -> None:
"""Set the platform integration token in the current context.
Args:
integration_token: The integration token to set.
"""
_platform_integration_token.set(integration_token)
def get_platform_integration_token() -> str | None:
"""Get the platform integration token from the current context or environment.
Returns:
The integration token if set, otherwise None.
"""
token = _platform_integration_token.get()
if token is None:
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
return token
@contextmanager
def platform_context(integration_token: str) -> Generator[None, Any, None]:
"""Context manager to temporarily set the platform integration token.
Args:
integration_token: The integration token to set within the context.
"""
token = _platform_integration_token.set(integration_token)
try:
yield
finally:
_platform_integration_token.reset(token)
"""Get the platform integration token from the current context."""
return _platform_integration_token.get()
_current_task_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(

View File

@@ -1,10 +1,8 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
import inspect
import json
import threading
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -780,7 +778,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
from_cache = cast(bool, execution_result["from_cache"])
original_tool = execution_result["original_tool"]
tool_message = {
tool_message: LLMMessage = {
"role": "tool",
"tool_call_id": call_id,
"name": func_name,
@@ -1360,9 +1358,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
formatted_answer: Current agent response.
"""
if self.step_callback:
cb_result = self.step_callback(formatted_answer)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
self.step_callback(formatted_answer)
def _append_message_to_state(
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
from concurrent.futures import Future
from copy import copy as shallow_copy
import datetime
@@ -625,15 +624,11 @@ class Task(BaseModel):
self.end_time = datetime.datetime.now()
if self.callback:
cb_result = self.callback(self.output)
if inspect.isawaitable(cb_result):
await cb_result
self.callback(self.output)
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
cb_result = crew.task_callback(self.output)
if inspect.isawaitable(cb_result):
await cb_result
crew.task_callback(self.output)
if self.output_file:
content = (
@@ -727,15 +722,11 @@ class Task(BaseModel):
self.end_time = datetime.datetime.now()
if self.callback:
cb_result = self.callback(self.output)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
self.callback(self.output)
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
cb_result = crew.task_callback(self.output)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
crew.task_callback(self.output)
if self.output_file:
content = (

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Sequence
import concurrent.futures
import inspect
import json
import re
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
@@ -502,9 +501,7 @@ def handle_agent_action_core(
- TODO: Remove messages parameter and its usage.
"""
if step_callback:
cb_result = step_callback(tool_result)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
step_callback(tool_result)
formatted_answer.text += f"\nObservation: {tool_result.result}"
formatted_answer.result = tool_result.result

View File

@@ -2,7 +2,7 @@
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -291,46 +291,6 @@ class TestAsyncAgentExecutor:
assert max_concurrent > 1, f"Expected concurrent execution, max concurrent was {max_concurrent}"
class TestInvokeStepCallback:
"""Tests for _invoke_step_callback with sync and async callbacks."""
def test_invoke_step_callback_with_sync_callback(
self, executor: CrewAgentExecutor
) -> None:
"""Test that a sync step callback is called normally."""
callback = Mock()
executor.step_callback = callback
answer = AgentFinish(thought="thinking", output="test", text="final")
executor._invoke_step_callback(answer)
callback.assert_called_once_with(answer)
def test_invoke_step_callback_with_async_callback(
self, executor: CrewAgentExecutor
) -> None:
"""Test that an async step callback is awaited via asyncio.run."""
async_callback = AsyncMock()
executor.step_callback = async_callback
answer = AgentFinish(thought="thinking", output="test", text="final")
with patch("crewai.agents.crew_agent_executor.asyncio.run") as mock_run:
executor._invoke_step_callback(answer)
async_callback.assert_called_once_with(answer)
mock_run.assert_called_once()
def test_invoke_step_callback_with_none(
self, executor: CrewAgentExecutor
) -> None:
"""Test that no error is raised when step_callback is None."""
executor.step_callback = None
answer = AgentFinish(thought="thinking", output="test", text="final")
# Should not raise
executor._invoke_step_callback(answer)
class TestAsyncLLMResponseHelper:
"""Tests for aget_llm_response helper function."""

View File

@@ -2,7 +2,7 @@ from datetime import datetime, timedelta
from unittest.mock import MagicMock, call, patch
import pytest
import httpx
import requests
from crewai.cli.authentication.main import AuthenticationCommand
from crewai.cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
@@ -220,7 +220,7 @@ class TestAuthenticationCommand:
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai.cli.authentication.main.httpx.post")
@patch("requests.post")
def test_get_device_code(self, mock_post):
mock_response = MagicMock()
mock_response.json.return_value = {
@@ -256,7 +256,7 @@ class TestAuthenticationCommand:
"verification_uri_complete": "https://example.com/auth",
}
@patch("crewai.cli.authentication.main.httpx.post")
@patch("requests.post")
@patch("crewai.cli.authentication.main.console.print")
def test_poll_for_token_success(self, mock_console_print, mock_post):
mock_response_success = MagicMock()
@@ -305,7 +305,7 @@ class TestAuthenticationCommand:
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai.cli.authentication.main.httpx.post")
@patch("requests.post")
@patch("crewai.cli.authentication.main.console.print")
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
mock_response_pending = MagicMock()
@@ -324,7 +324,7 @@ class TestAuthenticationCommand:
"Timeout: Failed to get the token. Please try again.", style="bold red"
)
@patch("crewai.cli.authentication.main.httpx.post")
@patch("requests.post")
def test_poll_for_token_error(self, mock_post):
"""Test the method to poll for token (error path)."""
# Setup mock to return error
@@ -338,5 +338,5 @@ class TestAuthenticationCommand:
device_code_data = {"device_code": "test_device_code", "interval": 1}
with pytest.raises(httpx.HTTPError):
with pytest.raises(requests.HTTPError):
self.auth_command._poll_for_token(device_code_data)

View File

@@ -4,11 +4,10 @@ from io import StringIO
from unittest.mock import MagicMock, Mock, patch
import pytest
import json
import httpx
import requests
from crewai.cli.deploy.main import DeployCommand
from crewai.cli.utils import parse_toml
from requests.exceptions import JSONDecodeError
class TestDeployCommand(unittest.TestCase):
@@ -38,18 +37,18 @@ class TestDeployCommand(unittest.TestCase):
DeployCommand()
def test_validate_response_successful_response(self):
mock_response = Mock(spec=httpx.Response)
mock_response = Mock(spec=requests.Response)
mock_response.json.return_value = {"message": "Success"}
mock_response.status_code = 200
mock_response.is_success = True
mock_response.ok = True
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command._validate_response(mock_response)
assert fake_out.getvalue() == ""
def test_validate_response_json_decode_error(self):
mock_response = Mock(spec=httpx.Response)
mock_response.json.side_effect = json.JSONDecodeError("Decode error", "", 0)
mock_response = Mock(spec=requests.Response)
mock_response.json.side_effect = JSONDecodeError("Decode error", "", 0)
mock_response.status_code = 500
mock_response.content = b"Invalid JSON"
@@ -65,13 +64,13 @@ class TestDeployCommand(unittest.TestCase):
assert "Response:\nInvalid JSON" in output
def test_validate_response_422_error(self):
mock_response = Mock(spec=httpx.Response)
mock_response = Mock(spec=requests.Response)
mock_response.json.return_value = {
"field1": ["Error message 1"],
"field2": ["Error message 2"],
}
mock_response.status_code = 422
mock_response.is_success = False
mock_response.ok = False
with patch("sys.stdout", new=StringIO()) as fake_out:
with pytest.raises(SystemExit):
@@ -85,10 +84,10 @@ class TestDeployCommand(unittest.TestCase):
assert "Field2 Error message 2" in output
def test_validate_response_other_error(self):
mock_response = Mock(spec=httpx.Response)
mock_response = Mock(spec=requests.Response)
mock_response.json.return_value = {"error": "Something went wrong"}
mock_response.status_code = 500
mock_response.is_success = False
mock_response.ok = False
with patch("sys.stdout", new=StringIO()) as fake_out:
with pytest.raises(SystemExit):

View File

@@ -3,9 +3,8 @@ import unittest
from pathlib import Path
from unittest.mock import Mock, patch
import json
import httpx
import requests
from requests.exceptions import JSONDecodeError
from crewai.cli.enterprise.main import EnterpriseConfigureCommand
from crewai.cli.settings.main import SettingsCommand
@@ -26,7 +25,7 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.test_dir)
@patch('crewai.cli.enterprise.main.httpx.get')
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_successful_configuration(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
@@ -74,23 +73,19 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
self.assertEqual(call_args[0], key)
self.assertEqual(call_args[1], value)
@patch('crewai.cli.enterprise.main.httpx.get')
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_http_error_handling(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
mock_response = Mock()
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"404 Not Found",
request=httpx.Request("GET", "http://test"),
response=httpx.Response(404),
)
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_requests_get.return_value = mock_response
with self.assertRaises(SystemExit):
self.enterprise_command.configure("https://enterprise.example.com")
@patch('crewai.cli.enterprise.main.httpx.get')
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_invalid_json_response(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
@@ -98,13 +93,13 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
mock_response = Mock()
mock_response.status_code = 200
mock_response.raise_for_status.return_value = None
mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
mock_response.json.side_effect = JSONDecodeError("Invalid JSON", "", 0)
mock_requests_get.return_value = mock_response
with self.assertRaises(SystemExit):
self.enterprise_command.configure("https://enterprise.example.com")
@patch('crewai.cli.enterprise.main.httpx.get')
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_missing_required_fields(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
@@ -120,7 +115,7 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
with self.assertRaises(SystemExit):
self.enterprise_command.configure("https://enterprise.example.com")
@patch('crewai.cli.enterprise.main.httpx.get')
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_settings_update_error(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"

View File

@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch, call
import pytest
from click.testing import CliRunner
import httpx
import requests
from crewai.cli.organization.main import OrganizationCommand
from crewai.cli.cli import org_list, switch, current
@@ -115,7 +115,7 @@ class TestOrganizationCommand(unittest.TestCase):
def test_list_organizations_api_error(self, mock_console):
self.org_command.plus_api_client = MagicMock()
self.org_command.plus_api_client.get_organizations.side_effect = (
httpx.HTTPError("API Error")
requests.exceptions.RequestException("API Error")
)
with pytest.raises(SystemExit):
@@ -201,10 +201,8 @@ class TestOrganizationCommand(unittest.TestCase):
@patch("crewai.cli.organization.main.console")
def test_list_organizations_unauthorized(self, mock_console):
mock_response = MagicMock()
mock_http_error = httpx.HTTPStatusError(
"401 Client Error: Unauthorized",
request=httpx.Request("GET", "http://test"),
response=httpx.Response(401),
mock_http_error = requests.exceptions.HTTPError(
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
)
mock_response.raise_for_status.side_effect = mock_http_error
@@ -221,10 +219,8 @@ class TestOrganizationCommand(unittest.TestCase):
@patch("crewai.cli.organization.main.console")
def test_switch_organization_unauthorized(self, mock_console):
mock_response = MagicMock()
mock_http_error = httpx.HTTPStatusError(
"401 Client Error: Unauthorized",
request=httpx.Request("GET", "http://test"),
response=httpx.Response(401),
mock_http_error = requests.exceptions.HTTPError(
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
)
mock_response.raise_for_status.side_effect = mock_http_error

View File

@@ -33,9 +33,9 @@ class TestPlusAPI(unittest.TestCase):
self.assertEqual(response, mock_response)
def assert_request_with_org_id(
self, mock_client_instance, method: str, endpoint: str, **kwargs
self, mock_make_request, method: str, endpoint: str, **kwargs
):
mock_client_instance.request.assert_called_once_with(
mock_make_request.assert_called_once_with(
method,
f"{os.getenv('CREWAI_PLUS_URL')}{endpoint}",
headers={
@@ -49,25 +49,24 @@ class TestPlusAPI(unittest.TestCase):
)
@patch("crewai.cli.plus_api.Settings")
@patch("crewai.cli.plus_api.httpx.Client")
@patch("requests.Session.request")
def test_login_to_tool_repository_with_org_uuid(
self, mock_client_class, mock_settings_class
self, mock_make_request, mock_settings_class
):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
# re-initialize Client
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
mock_make_request.return_value = mock_response
response = self.api.login_to_tool_repository()
self.assert_request_with_org_id(
mock_client_instance, "POST", "/crewai_plus/api/v1/tools/login"
mock_make_request, "POST", "/crewai_plus/api/v1/tools/login"
)
self.assertEqual(response, mock_response)
@@ -83,23 +82,23 @@ class TestPlusAPI(unittest.TestCase):
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.Settings")
@patch("crewai.cli.plus_api.httpx.Client")
def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
@patch("requests.Session.request")
def test_get_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
# re-initialize Client
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
# Set up mock response
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
mock_make_request.return_value = mock_response
response = self.api.get_tool("test_tool_handle")
self.assert_request_with_org_id(
mock_client_instance, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
mock_make_request, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@@ -131,18 +130,18 @@ class TestPlusAPI(unittest.TestCase):
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.Settings")
@patch("crewai.cli.plus_api.httpx.Client")
def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
@patch("requests.Session.request")
def test_publish_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
# re-initialize Client
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
# Set up mock response
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = True
@@ -154,6 +153,7 @@ class TestPlusAPI(unittest.TestCase):
handle, public, version, description, encoded_file
)
# Expected params including organization_uuid
expected_params = {
"handle": handle,
"public": public,
@@ -164,7 +164,7 @@ class TestPlusAPI(unittest.TestCase):
}
self.assert_request_with_org_id(
mock_client_instance, "POST", "/crewai_plus/api/v1/tools", json=expected_params
mock_make_request, "POST", "/crewai_plus/api/v1/tools", json=expected_params
)
self.assertEqual(response, mock_response)
@@ -195,19 +195,20 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.httpx.Client")
def test_make_request(self, mock_client_class):
mock_client_instance = MagicMock()
@patch("crewai.cli.plus_api.requests.Session")
def test_make_request(self, mock_session):
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
mock_session_instance = mock_session.return_value
mock_session_instance.request.return_value = mock_response
response = self.api._make_request("GET", "test_endpoint")
mock_client_class.assert_called_once_with(trust_env=False, verify=True)
mock_client_instance.request.assert_called_once_with(
mock_session.assert_called_once()
mock_session_instance.request.assert_called_once_with(
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers
)
mock_session_instance.trust_env = False
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")

View File

@@ -351,7 +351,7 @@ def test_publish_api_error(
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.json.return_value = {"error": "Internal Server Error"}
mock_response.is_success = False
mock_response.ok = False
mock_publish.return_value = mock_response
with raises(SystemExit):

View File

@@ -3,7 +3,7 @@ import subprocess
import unittest
from unittest.mock import Mock, patch
import httpx
import requests
from crewai.cli.triggers.main import TriggersCommand
@@ -21,7 +21,7 @@ class TestTriggersCommand(unittest.TestCase):
@patch("crewai.cli.triggers.main.console.print")
def test_list_triggers_success(self, mock_console_print):
mock_response = Mock(spec=httpx.Response)
mock_response = Mock(spec=requests.Response)
mock_response.status_code = 200
mock_response.ok = True
mock_response.json.return_value = {
@@ -50,7 +50,7 @@ class TestTriggersCommand(unittest.TestCase):
@patch("crewai.cli.triggers.main.console.print")
def test_list_triggers_no_apps(self, mock_console_print):
mock_response = Mock(spec=httpx.Response)
mock_response = Mock(spec=requests.Response)
mock_response.status_code = 200
mock_response.ok = True
mock_response.json.return_value = {"apps": []}
@@ -81,7 +81,7 @@ class TestTriggersCommand(unittest.TestCase):
@patch("crewai.cli.triggers.main.console.print")
@patch.object(TriggersCommand, "_run_crew_with_payload")
def test_execute_with_trigger_success(self, mock_run_crew, mock_console_print):
mock_response = Mock(spec=httpx.Response)
mock_response = Mock(spec=requests.Response)
mock_response.status_code = 200
mock_response.ok = True
mock_response.json.return_value = {
@@ -99,7 +99,7 @@ class TestTriggersCommand(unittest.TestCase):
@patch("crewai.cli.triggers.main.console.print")
def test_execute_with_trigger_not_found(self, mock_console_print):
mock_response = Mock(spec=httpx.Response)
mock_response = Mock(spec=requests.Response)
mock_response.status_code = 404
mock_response.json.return_value = {"error": "Trigger not found"}
self.mock_client.get_trigger_payload.return_value = mock_response
@@ -159,7 +159,7 @@ class TestTriggersCommand(unittest.TestCase):
@patch("crewai.cli.triggers.main.console.print")
def test_execute_with_trigger_with_default_error_message(self, mock_console_print):
mock_response = Mock(spec=httpx.Response)
mock_response = Mock(spec=requests.Response)
mock_response.status_code = 404
mock_response.json.return_value = {}
self.mock_client.get_trigger_payload.return_value = mock_response

View File

@@ -1,14 +1,8 @@
# ruff: noqa: S105
import os
from unittest.mock import patch
import pytest
from crewai.context import (
_platform_integration_token,
get_platform_integration_token,
platform_context,
set_platform_integration_token,
)
@@ -19,203 +13,15 @@ class TestPlatformIntegrationToken:
def teardown_method(self):
_platform_integration_token.set(None)
@patch.dict(os.environ, {}, clear=True)
def test_set_platform_integration_token(self):
test_token = "test-token-123"
def test_set_and_get(self):
assert get_platform_integration_token() is None
_platform_integration_token.set("test-token-123")
assert get_platform_integration_token() == "test-token-123"
def test_returns_none_when_not_set(self):
assert get_platform_integration_token() is None
set_platform_integration_token(test_token)
assert get_platform_integration_token() == test_token
def test_get_platform_integration_token_from_context_var(self):
test_token = "context-var-token"
_platform_integration_token.set(test_token)
assert get_platform_integration_token() == test_token
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token-456"})
def test_get_platform_integration_token_from_env_var(self):
assert _platform_integration_token.get() is None
assert get_platform_integration_token() == "env-token-456"
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token"})
def test_context_var_takes_precedence_over_env_var(self):
context_token = "context-token"
set_platform_integration_token(context_token)
assert get_platform_integration_token() == context_token
@patch.dict(os.environ, {}, clear=True)
def test_get_platform_integration_token_returns_none_when_not_set(self):
assert _platform_integration_token.get() is None
assert get_platform_integration_token() is None
@patch.dict(os.environ, {}, clear=True)
def test_platform_context_manager_basic_usage(self):
test_token = "context-manager-token"
assert get_platform_integration_token() is None
with platform_context(test_token):
assert get_platform_integration_token() == test_token
assert get_platform_integration_token() is None
@patch.dict(os.environ, {}, clear=True)
def test_platform_context_manager_nested_contexts(self):
"""Test nested platform_context context managers."""
outer_token = "outer-token"
inner_token = "inner-token"
assert get_platform_integration_token() is None
with platform_context(outer_token):
assert get_platform_integration_token() == outer_token
with platform_context(inner_token):
assert get_platform_integration_token() == inner_token
assert get_platform_integration_token() == outer_token
assert get_platform_integration_token() is None
def test_platform_context_manager_preserves_existing_token(self):
"""Test that platform_context preserves existing token when exiting."""
initial_token = "initial-token"
context_token = "context-token"
set_platform_integration_token(initial_token)
assert get_platform_integration_token() == initial_token
with platform_context(context_token):
assert get_platform_integration_token() == context_token
assert get_platform_integration_token() == initial_token
def test_platform_context_manager_exception_handling(self):
"""Test that platform_context properly resets token even when exception occurs."""
initial_token = "initial-token"
context_token = "context-token"
set_platform_integration_token(initial_token)
with pytest.raises(ValueError):
with platform_context(context_token):
assert get_platform_integration_token() == context_token
raise ValueError("Test exception")
assert get_platform_integration_token() == initial_token
@patch.dict(os.environ, {}, clear=True)
def test_platform_context_manager_with_none_initial_state(self):
"""Test platform_context when initial state is None."""
context_token = "context-token"
assert get_platform_integration_token() is None
with pytest.raises(RuntimeError):
with platform_context(context_token):
assert get_platform_integration_token() == context_token
raise RuntimeError("Test exception")
assert get_platform_integration_token() is None
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-backup"})
def test_platform_context_with_env_fallback(self):
"""Test platform_context interaction with environment variable fallback."""
context_token = "context-token"
assert get_platform_integration_token() == "env-backup"
with platform_context(context_token):
assert get_platform_integration_token() == context_token
assert get_platform_integration_token() == "env-backup"
@patch.dict(os.environ, {}, clear=True)
def test_multiple_sequential_context_managers(self):
"""Test multiple sequential uses of platform_context."""
token1 = "token-1"
token2 = "token-2"
token3 = "token-3"
with platform_context(token1):
assert get_platform_integration_token() == token1
assert get_platform_integration_token() is None
with platform_context(token2):
assert get_platform_integration_token() == token2
assert get_platform_integration_token() is None
with platform_context(token3):
assert get_platform_integration_token() == token3
assert get_platform_integration_token() is None
def test_empty_string_token(self):
empty_token = ""
set_platform_integration_token(empty_token)
assert get_platform_integration_token() == ""
with platform_context(empty_token):
assert get_platform_integration_token() == ""
def test_special_characters_in_token(self):
special_token = "token-with-!@#$%^&*()_+-={}[]|\\:;\"'<>?,./"
set_platform_integration_token(special_token)
assert get_platform_integration_token() == special_token
with platform_context(special_token):
assert get_platform_integration_token() == special_token
def test_very_long_token(self):
long_token = "a" * 10000
set_platform_integration_token(long_token)
assert get_platform_integration_token() == long_token
with platform_context(long_token):
assert get_platform_integration_token() == long_token
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": ""})
def test_empty_env_var(self):
assert _platform_integration_token.get() is None
assert get_platform_integration_token() == ""
@patch("crewai.context.os.getenv")
def test_env_var_access_error_handling(self, mock_getenv):
mock_getenv.side_effect = OSError("Environment access error")
with pytest.raises(OSError):
get_platform_integration_token()
@patch.dict(os.environ, {}, clear=True)
def test_context_var_isolation_between_tests(self):
"""Test that context variable changes don't leak between test methods."""
test_token = "isolation-test-token"
assert get_platform_integration_token() is None
set_platform_integration_token(test_token)
assert get_platform_integration_token() == test_token
def test_context_manager_return_value(self):
"""Test that platform_context can be used in with statement with return value."""
test_token = "return-value-token"
with platform_context(test_token):
assert get_platform_integration_token() == test_token
with platform_context(test_token) as ctx:
assert ctx is None
assert get_platform_integration_token() == test_token
def test_overwrite(self):
_platform_integration_token.set("first")
_platform_integration_token.set("second")
assert get_platform_integration_token() == "second"

2
uv.lock generated
View File

@@ -1096,7 +1096,6 @@ dependencies = [
{ name = "appdirs" },
{ name = "chromadb" },
{ name = "click" },
{ name = "httpx" },
{ name = "instructor" },
{ name = "json-repair" },
{ name = "json5" },
@@ -1196,7 +1195,6 @@ requires-dist = [
{ name = "crewai-tools", marker = "extra == 'tools'", editable = "lib/crewai-tools" },
{ name = "docling", marker = "extra == 'docling'", specifier = "~=2.63.0" },
{ name = "google-genai", marker = "extra == 'google-genai'", specifier = "~=1.49.0" },
{ name = "httpx", specifier = "~=0.28.1" },
{ name = "httpx-auth", marker = "extra == 'a2a'", specifier = "~=0.23.1" },
{ name = "httpx-sse", marker = "extra == 'a2a'", specifier = "~=0.4.0" },
{ name = "ibm-watsonx-ai", marker = "extra == 'watson'", specifier = "~=1.3.39" },