mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-11 05:22:41 +00:00
Compare commits
3 Commits
gl/refacto
...
devin/1771
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6bdad873a0 | ||
|
|
51754899a2 | ||
|
|
71b4f8402a |
@@ -38,6 +38,7 @@ dependencies = [
|
||||
"json5~=0.10.0",
|
||||
"portalocker~=2.7.0",
|
||||
"pydantic-settings~=2.10.1",
|
||||
"httpx~=0.28.1",
|
||||
"mcp~=1.26.0",
|
||||
"uv~=0.9.13",
|
||||
"aiosqlite~=0.21.0",
|
||||
|
||||
@@ -6,8 +6,10 @@ and memory management.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
@@ -736,7 +738,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
] = []
|
||||
for call_id, func_name, func_args in parsed_calls:
|
||||
original_tool = original_tools_by_name.get(func_name)
|
||||
execution_plan.append((call_id, func_name, func_args, original_tool))
|
||||
execution_plan.append(
|
||||
(call_id, func_name, func_args, original_tool)
|
||||
)
|
||||
|
||||
self._append_assistant_tool_calls_message(
|
||||
[
|
||||
@@ -746,7 +750,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
)
|
||||
|
||||
max_workers = min(8, len(execution_plan))
|
||||
ordered_results: list[dict[str, Any] | None] = [None] * len(execution_plan)
|
||||
ordered_results: list[dict[str, Any] | None] = [None] * len(
|
||||
execution_plan
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = {
|
||||
pool.submit(
|
||||
@@ -803,7 +809,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return tool_finish
|
||||
|
||||
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
|
||||
reasoning_message: LLMMessage = {
|
||||
reasoning_message = {
|
||||
"role": "user",
|
||||
"content": reasoning_prompt,
|
||||
}
|
||||
@@ -908,9 +914,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
elif (
|
||||
should_execute
|
||||
and original_tool
|
||||
and getattr(original_tool, "max_usage_count", None) is not None
|
||||
and getattr(original_tool, "current_usage_count", 0)
|
||||
>= original_tool.max_usage_count
|
||||
and (max_count := getattr(original_tool, "max_usage_count", None))
|
||||
is not None
|
||||
and getattr(original_tool, "current_usage_count", 0) >= max_count
|
||||
):
|
||||
max_usage_reached = True
|
||||
|
||||
@@ -989,13 +995,17 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
and hasattr(original_tool, "cache_function")
|
||||
and callable(original_tool.cache_function)
|
||||
):
|
||||
should_cache = original_tool.cache_function(args_dict, raw_result)
|
||||
should_cache = original_tool.cache_function(
|
||||
args_dict, raw_result
|
||||
)
|
||||
if should_cache:
|
||||
self.tools_handler.cache.add(
|
||||
tool=func_name, input=input_str, output=raw_result
|
||||
)
|
||||
|
||||
result = str(raw_result) if not isinstance(raw_result, str) else raw_result
|
||||
result = (
|
||||
str(raw_result) if not isinstance(raw_result, str) else raw_result
|
||||
)
|
||||
except Exception as e:
|
||||
result = f"Error executing tool: {e}"
|
||||
if self.task:
|
||||
@@ -1490,7 +1500,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
formatted_answer: Current agent response.
|
||||
"""
|
||||
if self.step_callback:
|
||||
self.step_callback(formatted_answer)
|
||||
cb_result = self.step_callback(formatted_answer)
|
||||
if inspect.iscoroutine(cb_result):
|
||||
asyncio.run(cb_result)
|
||||
|
||||
def _append_message(
|
||||
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"
|
||||
|
||||
@@ -2,8 +2,8 @@ import time
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
import webbrowser
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
import requests
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.cli.authentication.utils import validate_jwt_token
|
||||
@@ -98,7 +98,7 @@ class AuthenticationCommand:
|
||||
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
|
||||
"audience": self.oauth2_provider.get_audience(),
|
||||
}
|
||||
response = requests.post(
|
||||
response = httpx.post(
|
||||
url=self.oauth2_provider.get_authorize_url(),
|
||||
data=device_code_payload,
|
||||
timeout=20,
|
||||
@@ -130,7 +130,7 @@ class AuthenticationCommand:
|
||||
|
||||
attempts = 0
|
||||
while True and attempts < 10:
|
||||
response = requests.post(
|
||||
response = httpx.post(
|
||||
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
|
||||
)
|
||||
token_data = response.json()
|
||||
@@ -149,7 +149,7 @@ class AuthenticationCommand:
|
||||
return
|
||||
|
||||
if token_data["error"] not in ("authorization_pending", "slow_down"):
|
||||
raise requests.HTTPError(
|
||||
raise httpx.HTTPError(
|
||||
token_data.get("error_description") or token_data.get("error")
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import requests
|
||||
from requests.exceptions import JSONDecodeError
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.cli.authentication.token import get_auth_token
|
||||
@@ -30,16 +31,16 @@ class PlusAPIMixin:
|
||||
console.print("Run 'crewai login' to sign up/login.", style="bold green")
|
||||
raise SystemExit from None
|
||||
|
||||
def _validate_response(self, response: requests.Response) -> None:
|
||||
def _validate_response(self, response: httpx.Response) -> None:
|
||||
"""
|
||||
Handle and display error messages from API responses.
|
||||
|
||||
Args:
|
||||
response (requests.Response): The response from the Plus API
|
||||
response (httpx.Response): The response from the Plus API
|
||||
"""
|
||||
try:
|
||||
json_response = response.json()
|
||||
except (JSONDecodeError, ValueError):
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
console.print(
|
||||
"Failed to parse response from Enterprise API failed. Details:",
|
||||
style="bold red",
|
||||
@@ -62,7 +63,7 @@ class PlusAPIMixin:
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
if not response.ok:
|
||||
if not response.is_success:
|
||||
console.print(
|
||||
"Request to Enterprise API failed. Details:", style="bold red"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import requests
|
||||
from requests.exceptions import JSONDecodeError, RequestException
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
|
||||
@@ -47,12 +47,12 @@ class EnterpriseConfigureCommand(BaseCommand):
|
||||
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
|
||||
"X-Crewai-Version": get_crewai_version(),
|
||||
}
|
||||
response = requests.get(oauth_endpoint, timeout=30, headers=headers)
|
||||
response = httpx.get(oauth_endpoint, timeout=30, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
oauth_config = response.json()
|
||||
except JSONDecodeError as e:
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e
|
||||
|
||||
self._validate_oauth_config(oauth_config)
|
||||
@@ -62,7 +62,7 @@ class EnterpriseConfigureCommand(BaseCommand):
|
||||
)
|
||||
return cast(dict[str, Any], oauth_config)
|
||||
|
||||
except RequestException as e:
|
||||
except httpx.HTTPError as e:
|
||||
raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching OAuth2 configuration: {e!s}") from e
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from requests import HTTPError
|
||||
from httpx import HTTPStatusError
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
@@ -10,11 +10,11 @@ console = Console()
|
||||
|
||||
|
||||
class OrganizationCommand(BaseCommand, PlusAPIMixin):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
|
||||
def list(self):
|
||||
def list(self) -> None:
|
||||
try:
|
||||
response = self.plus_api_client.get_organizations()
|
||||
response.raise_for_status()
|
||||
@@ -33,7 +33,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
|
||||
table.add_row(org["name"], org["uuid"])
|
||||
|
||||
console.print(table)
|
||||
except HTTPError as e:
|
||||
except HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
console.print(
|
||||
"You are not logged in to any organization. Use 'crewai login' to login.",
|
||||
@@ -50,7 +50,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
|
||||
)
|
||||
raise SystemExit(1) from e
|
||||
|
||||
def switch(self, org_id):
|
||||
def switch(self, org_id: str) -> None:
|
||||
try:
|
||||
response = self.plus_api_client.get_organizations()
|
||||
response.raise_for_status()
|
||||
@@ -72,7 +72,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
|
||||
f"Successfully switched to {org['name']} ({org['uuid']})",
|
||||
style="bold green",
|
||||
)
|
||||
except HTTPError as e:
|
||||
except HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
console.print(
|
||||
"You are not logged in to any organization. Use 'crewai login' to login.",
|
||||
@@ -87,7 +87,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
|
||||
console.print(f"Failed to switch organization: {e!s}", style="bold red")
|
||||
raise SystemExit(1) from e
|
||||
|
||||
def current(self):
|
||||
def current(self) -> None:
|
||||
settings = Settings()
|
||||
if settings.org_uuid:
|
||||
console.print(
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
@@ -43,16 +42,16 @@ class PlusAPI:
|
||||
|
||||
def _make_request(
|
||||
self, method: str, endpoint: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
session = requests.Session()
|
||||
session.trust_env = False
|
||||
return session.request(method, url, headers=self.headers, **kwargs)
|
||||
verify = kwargs.pop("verify", True)
|
||||
with httpx.Client(trust_env=False, verify=verify) as client:
|
||||
return client.request(method, url, headers=self.headers, **kwargs)
|
||||
|
||||
def login_to_tool_repository(self) -> requests.Response:
|
||||
def login_to_tool_repository(self) -> httpx.Response:
|
||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
|
||||
|
||||
def get_tool(self, handle: str) -> requests.Response:
|
||||
def get_tool(self, handle: str) -> httpx.Response:
|
||||
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
||||
|
||||
async def get_agent(self, handle: str) -> httpx.Response:
|
||||
@@ -68,7 +67,7 @@ class PlusAPI:
|
||||
description: str | None,
|
||||
encoded_file: str,
|
||||
available_exports: list[dict[str, Any]] | None = None,
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": is_public,
|
||||
@@ -79,54 +78,52 @@ class PlusAPI:
|
||||
}
|
||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
|
||||
|
||||
def deploy_by_name(self, project_name: str) -> requests.Response:
|
||||
def deploy_by_name(self, project_name: str) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
|
||||
)
|
||||
|
||||
def deploy_by_uuid(self, uuid: str) -> requests.Response:
|
||||
def deploy_by_uuid(self, uuid: str) -> httpx.Response:
|
||||
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
|
||||
|
||||
def crew_status_by_name(self, project_name: str) -> requests.Response:
|
||||
def crew_status_by_name(self, project_name: str) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
|
||||
)
|
||||
|
||||
def crew_status_by_uuid(self, uuid: str) -> requests.Response:
|
||||
def crew_status_by_uuid(self, uuid: str) -> httpx.Response:
|
||||
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
|
||||
|
||||
def crew_by_name(
|
||||
self, project_name: str, log_type: str = "deployment"
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
|
||||
)
|
||||
|
||||
def crew_by_uuid(
|
||||
self, uuid: str, log_type: str = "deployment"
|
||||
) -> requests.Response:
|
||||
def crew_by_uuid(self, uuid: str, log_type: str = "deployment") -> httpx.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
|
||||
)
|
||||
|
||||
def delete_crew_by_name(self, project_name: str) -> requests.Response:
|
||||
def delete_crew_by_name(self, project_name: str) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
|
||||
)
|
||||
|
||||
def delete_crew_by_uuid(self, uuid: str) -> requests.Response:
|
||||
def delete_crew_by_uuid(self, uuid: str) -> httpx.Response:
|
||||
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
|
||||
|
||||
def list_crews(self) -> requests.Response:
|
||||
def list_crews(self) -> httpx.Response:
|
||||
return self._make_request("GET", self.CREWS_RESOURCE)
|
||||
|
||||
def create_crew(self, payload: dict[str, Any]) -> requests.Response:
|
||||
def create_crew(self, payload: dict[str, Any]) -> httpx.Response:
|
||||
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
|
||||
|
||||
def get_organizations(self) -> requests.Response:
|
||||
def get_organizations(self) -> httpx.Response:
|
||||
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
|
||||
|
||||
def initialize_trace_batch(self, payload: dict[str, Any]) -> requests.Response:
|
||||
def initialize_trace_batch(self, payload: dict[str, Any]) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.TRACING_RESOURCE}/batches",
|
||||
@@ -136,7 +133,7 @@ class PlusAPI:
|
||||
|
||||
def initialize_ephemeral_trace_batch(
|
||||
self, payload: dict[str, Any]
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
|
||||
@@ -145,7 +142,7 @@ class PlusAPI:
|
||||
|
||||
def send_trace_events(
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
||||
@@ -155,7 +152,7 @@ class PlusAPI:
|
||||
|
||||
def send_ephemeral_trace_events(
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
||||
@@ -165,7 +162,7 @@ class PlusAPI:
|
||||
|
||||
def finalize_trace_batch(
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
|
||||
@@ -175,7 +172,7 @@ class PlusAPI:
|
||||
|
||||
def finalize_ephemeral_trace_batch(
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
|
||||
@@ -185,7 +182,7 @@ class PlusAPI:
|
||||
|
||||
def mark_trace_batch_as_failed(
|
||||
self, trace_batch_id: str, error_message: str
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
|
||||
@@ -193,13 +190,11 @@ class PlusAPI:
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def get_triggers(self) -> requests.Response:
|
||||
def get_triggers(self) -> httpx.Response:
|
||||
"""Get all available triggers from integrations."""
|
||||
return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps")
|
||||
|
||||
def get_trigger_payload(
|
||||
self, app_slug: str, trigger_slug: str
|
||||
) -> requests.Response:
|
||||
def get_trigger_payload(self, app_slug: str, trigger_slug: str) -> httpx.Response:
|
||||
"""Get sample payload for a specific trigger."""
|
||||
return self._make_request(
|
||||
"GET", f"{self.INTEGRATIONS_RESOURCE}/{app_slug}/{trigger_slug}/payload"
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any
|
||||
|
||||
import certifi
|
||||
import click
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
|
||||
|
||||
@@ -165,20 +165,20 @@ def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
|
||||
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||
|
||||
try:
|
||||
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)
|
||||
response.raise_for_status()
|
||||
data = download_data(response)
|
||||
with open(cache_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
return data
|
||||
except requests.RequestException as e:
|
||||
with httpx.stream("GET", JSON_URL, timeout=60, verify=ssl_config) as response:
|
||||
response.raise_for_status()
|
||||
data = download_data(response)
|
||||
with open(cache_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
return data
|
||||
except httpx.HTTPError as e:
|
||||
click.secho(f"Error fetching provider data: {e}", fg="red")
|
||||
except json.JSONDecodeError:
|
||||
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
||||
return None
|
||||
|
||||
|
||||
def download_data(response: requests.Response) -> dict[str, Any]:
|
||||
def download_data(response: httpx.Response) -> dict[str, Any]:
|
||||
"""Downloads data from a given HTTP response and returns the JSON content.
|
||||
|
||||
Args:
|
||||
@@ -194,7 +194,7 @@ def download_data(response: requests.Response) -> dict[str, Any]:
|
||||
with click.progressbar(
|
||||
length=total_size, label="Downloading", show_pos=True
|
||||
) as bar:
|
||||
for chunk in response.iter_content(block_size):
|
||||
for chunk in response.iter_bytes(block_size):
|
||||
if chunk:
|
||||
data_chunks.append(chunk)
|
||||
bar.update(len(chunk))
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
import json
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
@@ -778,7 +780,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
from_cache = cast(bool, execution_result["from_cache"])
|
||||
original_tool = execution_result["original_tool"]
|
||||
|
||||
tool_message: LLMMessage = {
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"name": func_name,
|
||||
@@ -1358,7 +1360,9 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
formatted_answer: Current agent response.
|
||||
"""
|
||||
if self.step_callback:
|
||||
self.step_callback(formatted_answer)
|
||||
cb_result = self.step_callback(formatted_answer)
|
||||
if inspect.iscoroutine(cb_result):
|
||||
asyncio.run(cb_result)
|
||||
|
||||
def _append_message_to_state(
|
||||
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
import datetime
|
||||
@@ -624,11 +625,15 @@ class Task(BaseModel):
|
||||
self.end_time = datetime.datetime.now()
|
||||
|
||||
if self.callback:
|
||||
self.callback(self.output)
|
||||
cb_result = self.callback(self.output)
|
||||
if inspect.isawaitable(cb_result):
|
||||
await cb_result
|
||||
|
||||
crew = self.agent.crew # type: ignore[union-attr]
|
||||
if crew and crew.task_callback and crew.task_callback != self.callback:
|
||||
crew.task_callback(self.output)
|
||||
cb_result = crew.task_callback(self.output)
|
||||
if inspect.isawaitable(cb_result):
|
||||
await cb_result
|
||||
|
||||
if self.output_file:
|
||||
content = (
|
||||
@@ -722,11 +727,15 @@ class Task(BaseModel):
|
||||
self.end_time = datetime.datetime.now()
|
||||
|
||||
if self.callback:
|
||||
self.callback(self.output)
|
||||
cb_result = self.callback(self.output)
|
||||
if inspect.iscoroutine(cb_result):
|
||||
asyncio.run(cb_result)
|
||||
|
||||
crew = self.agent.crew # type: ignore[union-attr]
|
||||
if crew and crew.task_callback and crew.task_callback != self.callback:
|
||||
crew.task_callback(self.output)
|
||||
cb_result = crew.task_callback(self.output)
|
||||
if inspect.iscoroutine(cb_result):
|
||||
asyncio.run(cb_result)
|
||||
|
||||
if self.output_file:
|
||||
content = (
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
import concurrent.futures
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
|
||||
@@ -501,7 +502,9 @@ def handle_agent_action_core(
|
||||
- TODO: Remove messages parameter and its usage.
|
||||
"""
|
||||
if step_callback:
|
||||
step_callback(tool_result)
|
||||
cb_result = step_callback(tool_result)
|
||||
if inspect.iscoroutine(cb_result):
|
||||
asyncio.run(cb_result)
|
||||
|
||||
formatted_answer.text += f"\nObservation: {tool_result.result}"
|
||||
formatted_answer.result = tool_result.result
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -291,6 +291,46 @@ class TestAsyncAgentExecutor:
|
||||
assert max_concurrent > 1, f"Expected concurrent execution, max concurrent was {max_concurrent}"
|
||||
|
||||
|
||||
class TestInvokeStepCallback:
|
||||
"""Tests for _invoke_step_callback with sync and async callbacks."""
|
||||
|
||||
def test_invoke_step_callback_with_sync_callback(
|
||||
self, executor: CrewAgentExecutor
|
||||
) -> None:
|
||||
"""Test that a sync step callback is called normally."""
|
||||
callback = Mock()
|
||||
executor.step_callback = callback
|
||||
answer = AgentFinish(thought="thinking", output="test", text="final")
|
||||
|
||||
executor._invoke_step_callback(answer)
|
||||
|
||||
callback.assert_called_once_with(answer)
|
||||
|
||||
def test_invoke_step_callback_with_async_callback(
|
||||
self, executor: CrewAgentExecutor
|
||||
) -> None:
|
||||
"""Test that an async step callback is awaited via asyncio.run."""
|
||||
async_callback = AsyncMock()
|
||||
executor.step_callback = async_callback
|
||||
answer = AgentFinish(thought="thinking", output="test", text="final")
|
||||
|
||||
with patch("crewai.agents.crew_agent_executor.asyncio.run") as mock_run:
|
||||
executor._invoke_step_callback(answer)
|
||||
|
||||
async_callback.assert_called_once_with(answer)
|
||||
mock_run.assert_called_once()
|
||||
|
||||
def test_invoke_step_callback_with_none(
|
||||
self, executor: CrewAgentExecutor
|
||||
) -> None:
|
||||
"""Test that no error is raised when step_callback is None."""
|
||||
executor.step_callback = None
|
||||
answer = AgentFinish(thought="thinking", output="test", text="final")
|
||||
|
||||
# Should not raise
|
||||
executor._invoke_step_callback(answer)
|
||||
|
||||
|
||||
class TestAsyncLLMResponseHelper:
|
||||
"""Tests for aget_llm_response helper function."""
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import httpx
|
||||
from crewai.cli.authentication.main import AuthenticationCommand
|
||||
from crewai.cli.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
@@ -220,7 +220,7 @@ class TestAuthenticationCommand:
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("requests.post")
|
||||
@patch("crewai.cli.authentication.main.httpx.post")
|
||||
def test_get_device_code(self, mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
@@ -256,7 +256,7 @@ class TestAuthenticationCommand:
|
||||
"verification_uri_complete": "https://example.com/auth",
|
||||
}
|
||||
|
||||
@patch("requests.post")
|
||||
@patch("crewai.cli.authentication.main.httpx.post")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
def test_poll_for_token_success(self, mock_console_print, mock_post):
|
||||
mock_response_success = MagicMock()
|
||||
@@ -305,7 +305,7 @@ class TestAuthenticationCommand:
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("requests.post")
|
||||
@patch("crewai.cli.authentication.main.httpx.post")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
|
||||
mock_response_pending = MagicMock()
|
||||
@@ -324,7 +324,7 @@ class TestAuthenticationCommand:
|
||||
"Timeout: Failed to get the token. Please try again.", style="bold red"
|
||||
)
|
||||
|
||||
@patch("requests.post")
|
||||
@patch("crewai.cli.authentication.main.httpx.post")
|
||||
def test_poll_for_token_error(self, mock_post):
|
||||
"""Test the method to poll for token (error path)."""
|
||||
# Setup mock to return error
|
||||
@@ -338,5 +338,5 @@ class TestAuthenticationCommand:
|
||||
|
||||
device_code_data = {"device_code": "test_device_code", "interval": 1}
|
||||
|
||||
with pytest.raises(requests.HTTPError):
|
||||
with pytest.raises(httpx.HTTPError):
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
|
||||
@@ -4,10 +4,11 @@ from io import StringIO
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
from crewai.cli.utils import parse_toml
|
||||
from requests.exceptions import JSONDecodeError
|
||||
|
||||
|
||||
class TestDeployCommand(unittest.TestCase):
|
||||
@@ -37,18 +38,18 @@ class TestDeployCommand(unittest.TestCase):
|
||||
DeployCommand()
|
||||
|
||||
def test_validate_response_successful_response(self):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {"message": "Success"}
|
||||
mock_response.status_code = 200
|
||||
mock_response.ok = True
|
||||
mock_response.is_success = True
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command._validate_response(mock_response)
|
||||
assert fake_out.getvalue() == ""
|
||||
|
||||
def test_validate_response_json_decode_error(self):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response.json.side_effect = JSONDecodeError("Decode error", "", 0)
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Decode error", "", 0)
|
||||
mock_response.status_code = 500
|
||||
mock_response.content = b"Invalid JSON"
|
||||
|
||||
@@ -64,13 +65,13 @@ class TestDeployCommand(unittest.TestCase):
|
||||
assert "Response:\nInvalid JSON" in output
|
||||
|
||||
def test_validate_response_422_error(self):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {
|
||||
"field1": ["Error message 1"],
|
||||
"field2": ["Error message 2"],
|
||||
}
|
||||
mock_response.status_code = 422
|
||||
mock_response.ok = False
|
||||
mock_response.is_success = False
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with pytest.raises(SystemExit):
|
||||
@@ -84,10 +85,10 @@ class TestDeployCommand(unittest.TestCase):
|
||||
assert "Field2 Error message 2" in output
|
||||
|
||||
def test_validate_response_other_error(self):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {"error": "Something went wrong"}
|
||||
mock_response.status_code = 500
|
||||
mock_response.ok = False
|
||||
mock_response.is_success = False
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with pytest.raises(SystemExit):
|
||||
|
||||
@@ -3,8 +3,9 @@ import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import requests
|
||||
from requests.exceptions import JSONDecodeError
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from crewai.cli.enterprise.main import EnterpriseConfigureCommand
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
@@ -25,7 +26,7 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
@patch('crewai.cli.enterprise.main.requests.get')
|
||||
@patch('crewai.cli.enterprise.main.httpx.get')
|
||||
@patch('crewai.cli.enterprise.main.get_crewai_version')
|
||||
def test_successful_configuration(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
@@ -73,19 +74,23 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
|
||||
self.assertEqual(call_args[0], key)
|
||||
self.assertEqual(call_args[1], value)
|
||||
|
||||
@patch('crewai.cli.enterprise.main.requests.get')
|
||||
@patch('crewai.cli.enterprise.main.httpx.get')
|
||||
@patch('crewai.cli.enterprise.main.get_crewai_version')
|
||||
def test_http_error_handling(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"404 Not Found",
|
||||
request=httpx.Request("GET", "http://test"),
|
||||
response=httpx.Response(404),
|
||||
)
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.enterprise_command.configure("https://enterprise.example.com")
|
||||
|
||||
@patch('crewai.cli.enterprise.main.requests.get')
|
||||
@patch('crewai.cli.enterprise.main.httpx.get')
|
||||
@patch('crewai.cli.enterprise.main.get_crewai_version')
|
||||
def test_invalid_json_response(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
@@ -93,13 +98,13 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.side_effect = JSONDecodeError("Invalid JSON", "", 0)
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.enterprise_command.configure("https://enterprise.example.com")
|
||||
|
||||
@patch('crewai.cli.enterprise.main.requests.get')
|
||||
@patch('crewai.cli.enterprise.main.httpx.get')
|
||||
@patch('crewai.cli.enterprise.main.get_crewai_version')
|
||||
def test_missing_required_fields(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
@@ -115,7 +120,7 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
|
||||
with self.assertRaises(SystemExit):
|
||||
self.enterprise_command.configure("https://enterprise.example.com")
|
||||
|
||||
@patch('crewai.cli.enterprise.main.requests.get')
|
||||
@patch('crewai.cli.enterprise.main.httpx.get')
|
||||
@patch('crewai.cli.enterprise.main.get_crewai_version')
|
||||
def test_settings_update_error(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from crewai.cli.organization.main import OrganizationCommand
|
||||
from crewai.cli.cli import org_list, switch, current
|
||||
@@ -115,7 +115,7 @@ class TestOrganizationCommand(unittest.TestCase):
|
||||
def test_list_organizations_api_error(self, mock_console):
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
self.org_command.plus_api_client.get_organizations.side_effect = (
|
||||
requests.exceptions.RequestException("API Error")
|
||||
httpx.HTTPError("API Error")
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
@@ -201,8 +201,10 @@ class TestOrganizationCommand(unittest.TestCase):
|
||||
@patch("crewai.cli.organization.main.console")
|
||||
def test_list_organizations_unauthorized(self, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_http_error = requests.exceptions.HTTPError(
|
||||
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
|
||||
mock_http_error = httpx.HTTPStatusError(
|
||||
"401 Client Error: Unauthorized",
|
||||
request=httpx.Request("GET", "http://test"),
|
||||
response=httpx.Response(401),
|
||||
)
|
||||
|
||||
mock_response.raise_for_status.side_effect = mock_http_error
|
||||
@@ -219,8 +221,10 @@ class TestOrganizationCommand(unittest.TestCase):
|
||||
@patch("crewai.cli.organization.main.console")
|
||||
def test_switch_organization_unauthorized(self, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_http_error = requests.exceptions.HTTPError(
|
||||
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
|
||||
mock_http_error = httpx.HTTPStatusError(
|
||||
"401 Client Error: Unauthorized",
|
||||
request=httpx.Request("GET", "http://test"),
|
||||
response=httpx.Response(401),
|
||||
)
|
||||
|
||||
mock_response.raise_for_status.side_effect = mock_http_error
|
||||
|
||||
@@ -33,9 +33,9 @@ class TestPlusAPI(unittest.TestCase):
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
def assert_request_with_org_id(
|
||||
self, mock_make_request, method: str, endpoint: str, **kwargs
|
||||
self, mock_client_instance, method: str, endpoint: str, **kwargs
|
||||
):
|
||||
mock_make_request.assert_called_once_with(
|
||||
mock_client_instance.request.assert_called_once_with(
|
||||
method,
|
||||
f"{os.getenv('CREWAI_PLUS_URL')}{endpoint}",
|
||||
headers={
|
||||
@@ -49,24 +49,25 @@ class TestPlusAPI(unittest.TestCase):
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("requests.Session.request")
|
||||
@patch("crewai.cli.plus_api.httpx.Client")
|
||||
def test_login_to_tool_repository_with_org_uuid(
|
||||
self, mock_make_request, mock_settings_class
|
||||
self, mock_client_class, mock_settings_class
|
||||
):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
|
||||
mock_settings_class.return_value = mock_settings
|
||||
# re-initialize Client
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
response = self.api.login_to_tool_repository()
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_make_request, "POST", "/crewai_plus/api/v1/tools/login"
|
||||
mock_client_instance, "POST", "/crewai_plus/api/v1/tools/login"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@@ -82,23 +83,23 @@ class TestPlusAPI(unittest.TestCase):
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("requests.Session.request")
|
||||
def test_get_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
|
||||
@patch("crewai.cli.plus_api.httpx.Client")
|
||||
def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
|
||||
mock_settings_class.return_value = mock_settings
|
||||
# re-initialize Client
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
# Set up mock response
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
response = self.api.get_tool("test_tool_handle")
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_make_request, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
mock_client_instance, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@@ -130,18 +131,18 @@ class TestPlusAPI(unittest.TestCase):
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("requests.Session.request")
|
||||
def test_publish_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
|
||||
@patch("crewai.cli.plus_api.httpx.Client")
|
||||
def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
|
||||
mock_settings_class.return_value = mock_settings
|
||||
# re-initialize Client
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
# Set up mock response
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
@@ -153,7 +154,6 @@ class TestPlusAPI(unittest.TestCase):
|
||||
handle, public, version, description, encoded_file
|
||||
)
|
||||
|
||||
# Expected params including organization_uuid
|
||||
expected_params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
@@ -164,7 +164,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
}
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_make_request, "POST", "/crewai_plus/api/v1/tools", json=expected_params
|
||||
mock_client_instance, "POST", "/crewai_plus/api/v1/tools", json=expected_params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@@ -195,20 +195,19 @@ class TestPlusAPI(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.requests.Session")
|
||||
def test_make_request(self, mock_session):
|
||||
@patch("crewai.cli.plus_api.httpx.Client")
|
||||
def test_make_request(self, mock_client_class):
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
mock_session_instance = mock_session.return_value
|
||||
mock_session_instance.request.return_value = mock_response
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
response = self.api._make_request("GET", "test_endpoint")
|
||||
|
||||
mock_session.assert_called_once()
|
||||
mock_session_instance.request.assert_called_once_with(
|
||||
mock_client_class.assert_called_once_with(trust_env=False, verify=True)
|
||||
mock_client_instance.request.assert_called_once_with(
|
||||
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers
|
||||
)
|
||||
mock_session_instance.trust_env = False
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
|
||||
@@ -351,7 +351,7 @@ def test_publish_api_error(
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.json.return_value = {"error": "Internal Server Error"}
|
||||
mock_response.ok = False
|
||||
mock_response.is_success = False
|
||||
mock_publish.return_value = mock_response
|
||||
|
||||
with raises(SystemExit):
|
||||
|
||||
@@ -3,7 +3,7 @@ import subprocess
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from crewai.cli.triggers.main import TriggersCommand
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class TestTriggersCommand(unittest.TestCase):
|
||||
|
||||
@patch("crewai.cli.triggers.main.console.print")
|
||||
def test_list_triggers_success(self, mock_console_print):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {
|
||||
@@ -50,7 +50,7 @@ class TestTriggersCommand(unittest.TestCase):
|
||||
|
||||
@patch("crewai.cli.triggers.main.console.print")
|
||||
def test_list_triggers_no_apps(self, mock_console_print):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {"apps": []}
|
||||
@@ -81,7 +81,7 @@ class TestTriggersCommand(unittest.TestCase):
|
||||
@patch("crewai.cli.triggers.main.console.print")
|
||||
@patch.object(TriggersCommand, "_run_crew_with_payload")
|
||||
def test_execute_with_trigger_success(self, mock_run_crew, mock_console_print):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {
|
||||
@@ -99,7 +99,7 @@ class TestTriggersCommand(unittest.TestCase):
|
||||
|
||||
@patch("crewai.cli.triggers.main.console.print")
|
||||
def test_execute_with_trigger_not_found(self, mock_console_print):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 404
|
||||
mock_response.json.return_value = {"error": "Trigger not found"}
|
||||
self.mock_client.get_trigger_payload.return_value = mock_response
|
||||
@@ -159,7 +159,7 @@ class TestTriggersCommand(unittest.TestCase):
|
||||
|
||||
@patch("crewai.cli.triggers.main.console.print")
|
||||
def test_execute_with_trigger_with_default_error_message(self, mock_console_print):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 404
|
||||
mock_response.json.return_value = {}
|
||||
self.mock_client.get_trigger_payload.return_value = mock_response
|
||||
|
||||
443
lib/crewai/tests/test_ci_check_classifier.py
Normal file
443
lib/crewai/tests/test_ci_check_classifier.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""Tests for the deterministic CI check-state classifier.
|
||||
|
||||
Covers every category defined in the acceptance criteria for issue #4576:
|
||||
- passed
|
||||
- failed
|
||||
- pending
|
||||
- no_checks
|
||||
- policy_blocked
|
||||
|
||||
Also validates that source check metadata is retained for audit/review,
|
||||
and that the output contract (JSON shape) is stable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dynamic import of the classifier script from ``scripts/``
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SCRIPT_PATH = Path(__file__).resolve().parents[3] / "scripts" / "classify_ci_checks.py"
|
||||
|
||||
_spec = importlib.util.spec_from_file_location("classify_ci_checks", _SCRIPT_PATH)
|
||||
assert _spec is not None and _spec.loader is not None
|
||||
_mod = importlib.util.module_from_spec(_spec)
|
||||
_spec.loader.exec_module(_mod)
|
||||
|
||||
classify = _mod.classify
|
||||
main = _mod.main
|
||||
PASSED = _mod.PASSED
|
||||
FAILED = _mod.FAILED
|
||||
PENDING = _mod.PENDING
|
||||
NO_CHECKS = _mod.NO_CHECKS
|
||||
POLICY_BLOCKED = _mod.POLICY_BLOCKED
|
||||
ALL_STATES = _mod.ALL_STATES
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_check_run(
|
||||
name: str = "ci",
|
||||
status: str = "completed",
|
||||
conclusion: str = "success",
|
||||
started_at: str = "2026-01-01T00:00:00Z",
|
||||
completed_at: str = "2026-01-01T00:05:00Z",
|
||||
) -> dict[str, Any]:
|
||||
"""Build a minimal GitHub check-run dict."""
|
||||
return {
|
||||
"name": name,
|
||||
"status": status,
|
||||
"conclusion": conclusion,
|
||||
"started_at": started_at,
|
||||
"completed_at": completed_at,
|
||||
}
|
||||
|
||||
|
||||
def _make_commit_status(
|
||||
context: str = "ci/status",
|
||||
state: str = "success",
|
||||
updated_at: str = "2026-01-01T00:05:00Z",
|
||||
) -> dict[str, str]:
|
||||
"""Build a minimal GitHub commit-status dict."""
|
||||
return {
|
||||
"context": context,
|
||||
"state": state,
|
||||
"updated_at": updated_at,
|
||||
}
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Category: no_checks
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestNoChecks:
|
||||
"""When there are zero check runs and zero statuses -> ``no_checks``."""
|
||||
|
||||
def test_empty_check_runs_list(self) -> None:
|
||||
result = classify({"check_runs": []})
|
||||
assert result["state"] == NO_CHECKS
|
||||
assert result["total"] == 0
|
||||
|
||||
def test_empty_payload(self) -> None:
|
||||
result = classify({})
|
||||
assert result["state"] == NO_CHECKS
|
||||
assert result["total"] == 0
|
||||
|
||||
def test_empty_check_runs_and_statuses(self) -> None:
|
||||
result = classify({"check_runs": [], "statuses": []})
|
||||
assert result["state"] == NO_CHECKS
|
||||
assert result["total"] == 0
|
||||
|
||||
def test_summary_message(self) -> None:
|
||||
result = classify({"check_runs": []})
|
||||
assert "No CI checks found" in result["summary"]
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Category: passed
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestPassed:
|
||||
"""All checks completed successfully -> ``passed``."""
|
||||
|
||||
def test_single_success(self) -> None:
|
||||
result = classify({"check_runs": [_make_check_run()]})
|
||||
assert result["state"] == PASSED
|
||||
assert result["total"] == 1
|
||||
|
||||
def test_multiple_successes(self) -> None:
|
||||
runs = [
|
||||
_make_check_run(name="lint"),
|
||||
_make_check_run(name="tests (3.10)"),
|
||||
_make_check_run(name="tests (3.12)"),
|
||||
]
|
||||
result = classify({"check_runs": runs})
|
||||
assert result["state"] == PASSED
|
||||
assert result["total"] == 3
|
||||
|
||||
def test_neutral_conclusion_counts_as_passed(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(conclusion="neutral")]}
|
||||
)
|
||||
assert result["state"] == PASSED
|
||||
|
||||
def test_skipped_conclusion_counts_as_passed(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(conclusion="skipped")]}
|
||||
)
|
||||
assert result["state"] == PASSED
|
||||
|
||||
def test_commit_status_success(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [], "statuses": [_make_commit_status(state="success")]}
|
||||
)
|
||||
assert result["state"] == PASSED
|
||||
|
||||
def test_mixed_check_runs_and_statuses_all_pass(self) -> None:
|
||||
result = classify({
|
||||
"check_runs": [_make_check_run(name="build")],
|
||||
"statuses": [_make_commit_status(context="deploy", state="success")],
|
||||
})
|
||||
assert result["state"] == PASSED
|
||||
assert result["total"] == 2
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Category: failed
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestFailed:
|
||||
"""At least one check failed -> ``failed``."""
|
||||
|
||||
def test_single_failure(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(conclusion="failure")]}
|
||||
)
|
||||
assert result["state"] == FAILED
|
||||
|
||||
def test_timed_out(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(conclusion="timed_out")]}
|
||||
)
|
||||
assert result["state"] == FAILED
|
||||
|
||||
def test_cancelled(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(conclusion="cancelled")]}
|
||||
)
|
||||
assert result["state"] == FAILED
|
||||
|
||||
def test_startup_failure(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(conclusion="startup_failure")]}
|
||||
)
|
||||
assert result["state"] == FAILED
|
||||
|
||||
def test_failure_among_successes(self) -> None:
|
||||
runs = [
|
||||
_make_check_run(name="lint"),
|
||||
_make_check_run(name="tests", conclusion="failure"),
|
||||
_make_check_run(name="build"),
|
||||
]
|
||||
result = classify({"check_runs": runs})
|
||||
assert result["state"] == FAILED
|
||||
assert result["total"] == 3
|
||||
|
||||
def test_commit_status_failure(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [], "statuses": [_make_commit_status(state="failure")]}
|
||||
)
|
||||
assert result["state"] == FAILED
|
||||
|
||||
def test_commit_status_error(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [], "statuses": [_make_commit_status(state="error")]}
|
||||
)
|
||||
assert result["state"] == FAILED
|
||||
|
||||
def test_failed_overrides_pending(self) -> None:
|
||||
"""Failed takes precedence over pending."""
|
||||
runs = [
|
||||
_make_check_run(name="lint", status="in_progress", conclusion=""),
|
||||
_make_check_run(name="tests", conclusion="failure"),
|
||||
]
|
||||
result = classify({"check_runs": runs})
|
||||
assert result["state"] == FAILED
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Category: pending
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestPending:
|
||||
"""At least one check still in progress or queued -> ``pending``."""
|
||||
|
||||
def test_queued(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(status="queued", conclusion="")]}
|
||||
)
|
||||
assert result["state"] == PENDING
|
||||
|
||||
def test_in_progress(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(status="in_progress", conclusion="")]}
|
||||
)
|
||||
assert result["state"] == PENDING
|
||||
|
||||
def test_waiting(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(status="waiting", conclusion="")]}
|
||||
)
|
||||
assert result["state"] == PENDING
|
||||
|
||||
def test_pending_among_successes(self) -> None:
|
||||
runs = [
|
||||
_make_check_run(name="lint"),
|
||||
_make_check_run(name="tests", status="in_progress", conclusion=""),
|
||||
]
|
||||
result = classify({"check_runs": runs})
|
||||
assert result["state"] == PENDING
|
||||
|
||||
def test_commit_status_pending(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [], "statuses": [_make_commit_status(state="pending")]}
|
||||
)
|
||||
assert result["state"] == PENDING
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Category: policy_blocked
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestPolicyBlocked:
|
||||
"""A check requires manual action -> ``policy_blocked``."""
|
||||
|
||||
def test_action_required(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(conclusion="action_required")]}
|
||||
)
|
||||
assert result["state"] == POLICY_BLOCKED
|
||||
|
||||
def test_policy_blocked_overrides_failed(self) -> None:
|
||||
"""policy_blocked has highest priority after no_checks."""
|
||||
runs = [
|
||||
_make_check_run(name="lint", conclusion="failure"),
|
||||
_make_check_run(name="review", conclusion="action_required"),
|
||||
]
|
||||
result = classify({"check_runs": runs})
|
||||
assert result["state"] == POLICY_BLOCKED
|
||||
|
||||
def test_policy_blocked_overrides_pending(self) -> None:
|
||||
runs = [
|
||||
_make_check_run(name="build", status="in_progress", conclusion=""),
|
||||
_make_check_run(name="policy", conclusion="action_required"),
|
||||
]
|
||||
result = classify({"check_runs": runs})
|
||||
assert result["state"] == POLICY_BLOCKED
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Output contract / metadata retention
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestOutputContract:
|
||||
"""The JSON output has a stable shape and retains source metadata."""
|
||||
|
||||
def test_result_keys(self) -> None:
|
||||
result = classify({"check_runs": [_make_check_run()]})
|
||||
assert set(result.keys()) == {"state", "total", "summary", "checks"}
|
||||
|
||||
def test_state_is_a_known_value(self) -> None:
|
||||
for conclusion in ("success", "failure", "action_required"):
|
||||
result = classify({"check_runs": [_make_check_run(conclusion=conclusion)]})
|
||||
assert result["state"] in ALL_STATES
|
||||
|
||||
def test_check_metadata_retained(self) -> None:
|
||||
cr = _make_check_run(name="my-job", conclusion="success")
|
||||
result = classify({"check_runs": [cr]})
|
||||
meta = result["checks"][0]
|
||||
assert meta["name"] == "my-job"
|
||||
assert meta["status"] == "completed"
|
||||
assert meta["conclusion"] == "success"
|
||||
assert meta["started_at"] == "2026-01-01T00:00:00Z"
|
||||
assert meta["completed_at"] == "2026-01-01T00:05:00Z"
|
||||
|
||||
def test_commit_status_metadata_retained(self) -> None:
|
||||
cs = _make_commit_status(context="ci/deploy", state="success")
|
||||
result = classify({"check_runs": [], "statuses": [cs]})
|
||||
meta = result["checks"][0]
|
||||
assert meta["name"] == "ci/deploy"
|
||||
assert meta["status"] == "success"
|
||||
|
||||
def test_result_is_json_serialisable(self) -> None:
|
||||
result = classify({
|
||||
"check_runs": [_make_check_run()],
|
||||
"statuses": [_make_commit_status()],
|
||||
})
|
||||
roundtripped = json.loads(json.dumps(result))
|
||||
assert roundtripped == result
|
||||
|
||||
def test_total_matches_checks_length(self) -> None:
|
||||
runs = [_make_check_run(name=f"job-{i}") for i in range(5)]
|
||||
result = classify({"check_runs": runs})
|
||||
assert result["total"] == len(result["checks"]) == 5
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# CLI entry-point (main)
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestCLI:
|
||||
"""Test the ``main()`` function that wraps classify for CLI use."""
|
||||
|
||||
def test_exit_code_passed(self, tmp_path: Path) -> None:
|
||||
payload = {"check_runs": [_make_check_run()]}
|
||||
f = tmp_path / "input.json"
|
||||
f.write_text(json.dumps(payload))
|
||||
assert main([str(f)]) == 0
|
||||
|
||||
def test_exit_code_failed(self, tmp_path: Path) -> None:
|
||||
payload = {"check_runs": [_make_check_run(conclusion="failure")]}
|
||||
f = tmp_path / "input.json"
|
||||
f.write_text(json.dumps(payload))
|
||||
assert main([str(f)]) == 1
|
||||
|
||||
def test_exit_code_pending(self, tmp_path: Path) -> None:
|
||||
payload = {"check_runs": [_make_check_run(status="queued", conclusion="")]}
|
||||
f = tmp_path / "input.json"
|
||||
f.write_text(json.dumps(payload))
|
||||
assert main([str(f)]) == 2
|
||||
|
||||
def test_exit_code_no_checks(self, tmp_path: Path) -> None:
|
||||
payload = {"check_runs": []}
|
||||
f = tmp_path / "input.json"
|
||||
f.write_text(json.dumps(payload))
|
||||
assert main([str(f)]) == 2
|
||||
|
||||
def test_exit_code_policy_blocked(self, tmp_path: Path) -> None:
|
||||
payload = {"check_runs": [_make_check_run(conclusion="action_required")]}
|
||||
f = tmp_path / "input.json"
|
||||
f.write_text(json.dumps(payload))
|
||||
assert main([str(f)]) == 1
|
||||
|
||||
def test_invalid_json_returns_error(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "bad.json"
|
||||
f.write_text("NOT JSON")
|
||||
assert main([str(f)]) == 1
|
||||
|
||||
def test_missing_file_returns_error(self) -> None:
|
||||
assert main(["/nonexistent/path.json"]) == 1
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Edge cases
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Boundary and edge-case scenarios."""
|
||||
|
||||
def test_check_run_with_missing_fields(self) -> None:
|
||||
"""Gracefully handles check runs that omit optional fields."""
|
||||
result = classify({"check_runs": [{"status": "completed", "conclusion": "success"}]})
|
||||
assert result["state"] == PASSED
|
||||
meta = result["checks"][0]
|
||||
assert meta["name"] == ""
|
||||
assert meta["started_at"] == ""
|
||||
|
||||
def test_case_insensitive_conclusion(self) -> None:
|
||||
"""Conclusion strings are normalised to lowercase."""
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(conclusion="FAILURE")]}
|
||||
)
|
||||
assert result["state"] == FAILED
|
||||
|
||||
def test_case_insensitive_status(self) -> None:
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(status="IN_PROGRESS", conclusion="")]}
|
||||
)
|
||||
assert result["state"] == PENDING
|
||||
|
||||
def test_stale_conclusion_is_not_failure(self) -> None:
|
||||
"""``stale`` is a non-blocking conclusion."""
|
||||
result = classify(
|
||||
{"check_runs": [_make_check_run(conclusion="stale")]}
|
||||
)
|
||||
assert result["state"] == PASSED
|
||||
|
||||
def test_large_number_of_checks(self) -> None:
|
||||
"""Classifier handles many checks without error."""
|
||||
runs = [_make_check_run(name=f"job-{i}") for i in range(500)]
|
||||
result = classify({"check_runs": runs})
|
||||
assert result["state"] == PASSED
|
||||
assert result["total"] == 500
|
||||
|
||||
def test_mixed_all_states(self) -> None:
|
||||
"""When all state types are present, policy_blocked wins."""
|
||||
runs = [
|
||||
_make_check_run(name="pass", conclusion="success"),
|
||||
_make_check_run(name="fail", conclusion="failure"),
|
||||
_make_check_run(name="pend", status="queued", conclusion=""),
|
||||
_make_check_run(name="block", conclusion="action_required"),
|
||||
]
|
||||
result = classify({"check_runs": runs})
|
||||
assert result["state"] == POLICY_BLOCKED
|
||||
assert result["total"] == 4
|
||||
293
scripts/classify_ci_checks.py
Normal file
293
scripts/classify_ci_checks.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""Deterministic CI check-state classifier for CrewAI PR triage.
|
||||
|
||||
Normalizes raw GitHub CI check data into a deterministic JSON contract
|
||||
so that cross-repo planning and execution can rely on a stable,
|
||||
machine-readable CI-state output.
|
||||
|
||||
Categories
|
||||
----------
|
||||
- ``passed`` -- every check completed successfully
|
||||
- ``failed`` -- at least one check failed, timed-out, or was cancelled
|
||||
- ``pending`` -- at least one check is still queued or in progress
|
||||
- ``no_checks`` -- the PR has no associated check runs or commit statuses
|
||||
- ``policy_blocked`` -- at least one check requires manual action (e.g. review)
|
||||
|
||||
Usage
|
||||
-----
|
||||
Pipe JSON from the GitHub Checks API (or a compatible payload) into stdin::
|
||||
|
||||
gh api repos/{owner}/{repo}/commits/{ref}/check-runs | python scripts/classify_ci_checks.py
|
||||
|
||||
Or supply a file path as the first positional argument::
|
||||
|
||||
python scripts/classify_ci_checks.py checks.json
|
||||
|
||||
The script prints a single JSON object to stdout and exits with code 0 for
|
||||
``passed``, 1 for ``failed``/``policy_blocked``, and 2 for ``pending``/``no_checks``.
|
||||
|
||||
Example output::
|
||||
|
||||
{
|
||||
"state": "failed",
|
||||
"total": 3,
|
||||
"summary": "1 failed, 2 passed (3 total)",
|
||||
"checks": [
|
||||
{
|
||||
"name": "tests (3.12)",
|
||||
"status": "completed",
|
||||
"conclusion": "failure",
|
||||
"started_at": "2026-02-24T10:00:00Z",
|
||||
"completed_at": "2026-02-24T10:05:00Z"
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public state constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PASSED: str = "passed"
|
||||
FAILED: str = "failed"
|
||||
PENDING: str = "pending"
|
||||
NO_CHECKS: str = "no_checks"
|
||||
POLICY_BLOCKED: str = "policy_blocked"
|
||||
|
||||
ALL_STATES: frozenset[str] = frozenset(
|
||||
{PASSED, FAILED, PENDING, NO_CHECKS, POLICY_BLOCKED}
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal mapping helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# GitHub check-run conclusions that map to *failed*
|
||||
_FAILED_CONCLUSIONS: frozenset[str] = frozenset(
|
||||
{"failure", "timed_out", "cancelled", "startup_failure"}
|
||||
)
|
||||
|
||||
# GitHub check-run conclusions that map to *policy_blocked*
|
||||
_POLICY_CONCLUSIONS: frozenset[str] = frozenset({"action_required"})
|
||||
|
||||
# GitHub check-run statuses that map to *pending*
|
||||
_PENDING_STATUSES: frozenset[str] = frozenset({"queued", "in_progress", "waiting"})
|
||||
|
||||
# GitHub commit-status states that map to *failed*
|
||||
_FAILED_COMMIT_STATES: frozenset[str] = frozenset({"failure", "error"})
|
||||
|
||||
# GitHub commit-status states that map to *pending*
|
||||
_PENDING_COMMIT_STATES: frozenset[str] = frozenset({"pending"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core classifier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_check_metadata(check: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract audit-relevant metadata from a single check run or status.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
check:
|
||||
A single check-run or commit-status object from the GitHub API.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict:
|
||||
Normalized metadata dict with name, status, conclusion, and timestamps.
|
||||
"""
|
||||
return {
|
||||
"name": check.get("name") or check.get("context") or "",
|
||||
"status": check.get("status") or check.get("state") or "",
|
||||
"conclusion": check.get("conclusion") or "",
|
||||
"started_at": check.get("started_at") or "",
|
||||
"completed_at": check.get("completed_at") or check.get("updated_at") or "",
|
||||
}
|
||||
|
||||
|
||||
def classify(payload: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Classify CI check data into a deterministic state.
|
||||
|
||||
Accepts the JSON body returned by the GitHub ``check-runs`` endpoint
|
||||
(which wraps runs in ``{"total_count": N, "check_runs": [...]}``),
|
||||
a plain list of check-run objects, or a combined payload that also
|
||||
includes commit statuses under the ``statuses`` key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
payload:
|
||||
Raw GitHub API response or a dict with ``check_runs`` and/or
|
||||
``statuses`` lists.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict:
|
||||
Deterministic JSON-serialisable result with keys ``state``,
|
||||
``total``, ``summary``, and ``checks`` (source metadata).
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> result = classify({"check_runs": [], "statuses": []})
|
||||
>>> result["state"]
|
||||
'no_checks'
|
||||
|
||||
>>> result = classify({
|
||||
... "check_runs": [
|
||||
... {"name": "lint", "status": "completed", "conclusion": "success"}
|
||||
... ]
|
||||
... })
|
||||
>>> result["state"]
|
||||
'passed'
|
||||
"""
|
||||
# Normalise input: accept top-level list or wrapped object
|
||||
if isinstance(payload.get("check_runs"), list):
|
||||
check_runs: list[dict[str, Any]] = payload["check_runs"]
|
||||
elif isinstance(payload, list): # type: ignore[arg-type]
|
||||
check_runs = payload # type: ignore[assignment]
|
||||
else:
|
||||
check_runs = []
|
||||
|
||||
statuses: list[dict[str, Any]] = payload.get("statuses", []) if isinstance(payload, dict) else []
|
||||
|
||||
all_metadata: list[dict[str, Any]] = []
|
||||
has_policy_blocked = False
|
||||
has_failed = False
|
||||
has_pending = False
|
||||
|
||||
# --- Classify check runs ---
|
||||
for cr in check_runs:
|
||||
all_metadata.append(_extract_check_metadata(cr))
|
||||
status = (cr.get("status") or "").lower()
|
||||
conclusion = (cr.get("conclusion") or "").lower()
|
||||
|
||||
if conclusion in _POLICY_CONCLUSIONS:
|
||||
has_policy_blocked = True
|
||||
elif conclusion in _FAILED_CONCLUSIONS:
|
||||
has_failed = True
|
||||
elif status in _PENDING_STATUSES:
|
||||
has_pending = True
|
||||
# completed + success/neutral/skipped/stale → not a problem
|
||||
|
||||
# --- Classify commit statuses ---
|
||||
for cs in statuses:
|
||||
all_metadata.append(_extract_check_metadata(cs))
|
||||
state = (cs.get("state") or "").lower()
|
||||
|
||||
if state in _FAILED_COMMIT_STATES:
|
||||
has_failed = True
|
||||
elif state in _PENDING_COMMIT_STATES:
|
||||
has_pending = True
|
||||
|
||||
# --- Determine aggregate state (priority order) ---
|
||||
total = len(all_metadata)
|
||||
|
||||
if total == 0:
|
||||
state = NO_CHECKS
|
||||
elif has_policy_blocked:
|
||||
state = POLICY_BLOCKED
|
||||
elif has_failed:
|
||||
state = FAILED
|
||||
elif has_pending:
|
||||
state = PENDING
|
||||
else:
|
||||
state = PASSED
|
||||
|
||||
return {
|
||||
"state": state,
|
||||
"total": total,
|
||||
"summary": _build_summary(state, all_metadata),
|
||||
"checks": all_metadata,
|
||||
}
|
||||
|
||||
|
||||
def _build_summary(state: str, checks: list[dict[str, Any]]) -> str:
|
||||
"""Build a human-readable one-line summary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state:
|
||||
The classified state string.
|
||||
checks:
|
||||
List of normalized check metadata dicts.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str:
|
||||
Human-readable summary string.
|
||||
"""
|
||||
total = len(checks)
|
||||
if total == 0:
|
||||
return "No CI checks found"
|
||||
|
||||
# Count by conclusion/status bucket
|
||||
counts: dict[str, int] = {}
|
||||
for c in checks:
|
||||
conclusion = c.get("conclusion", "")
|
||||
status = c.get("status", "")
|
||||
# Use conclusion if available, otherwise status
|
||||
bucket = conclusion if conclusion else status
|
||||
if not bucket:
|
||||
bucket = "unknown"
|
||||
counts[bucket] = counts.get(bucket, 0) + 1
|
||||
|
||||
parts = [f"{v} {k}" for k, v in sorted(counts.items())]
|
||||
return f"{', '.join(parts)} ({total} total)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI entry-point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Exit codes aligned to state severity
|
||||
_EXIT_CODES: dict[str, int] = {
|
||||
PASSED: 0,
|
||||
FAILED: 1,
|
||||
POLICY_BLOCKED: 1,
|
||||
PENDING: 2,
|
||||
NO_CHECKS: 2,
|
||||
}
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
"""CLI entry-point: read JSON from *stdin* or a file and classify.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
argv:
|
||||
Command-line arguments (default: ``sys.argv[1:]``).
|
||||
|
||||
Returns
|
||||
-------
|
||||
int:
|
||||
Exit code (0 = passed, 1 = failed/blocked, 2 = pending/no checks).
|
||||
"""
|
||||
args = argv if argv is not None else sys.argv[1:]
|
||||
|
||||
try:
|
||||
if args:
|
||||
with open(args[0]) as fh:
|
||||
raw = fh.read()
|
||||
else:
|
||||
raw = sys.stdin.read()
|
||||
|
||||
payload = json.loads(raw)
|
||||
except (json.JSONDecodeError, OSError) as exc:
|
||||
print(json.dumps({"error": str(exc)}), file=sys.stderr) # noqa: T201
|
||||
return 1
|
||||
|
||||
result = classify(payload)
|
||||
print(json.dumps(result, indent=2)) # noqa: T201
|
||||
return _EXIT_CODES.get(result["state"], 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -1096,6 +1096,7 @@ dependencies = [
|
||||
{ name = "appdirs" },
|
||||
{ name = "chromadb" },
|
||||
{ name = "click" },
|
||||
{ name = "httpx" },
|
||||
{ name = "instructor" },
|
||||
{ name = "json-repair" },
|
||||
{ name = "json5" },
|
||||
@@ -1195,6 +1196,7 @@ requires-dist = [
|
||||
{ name = "crewai-tools", marker = "extra == 'tools'", editable = "lib/crewai-tools" },
|
||||
{ name = "docling", marker = "extra == 'docling'", specifier = "~=2.63.0" },
|
||||
{ name = "google-genai", marker = "extra == 'google-genai'", specifier = "~=1.49.0" },
|
||||
{ name = "httpx", specifier = "~=0.28.1" },
|
||||
{ name = "httpx-auth", marker = "extra == 'a2a'", specifier = "~=0.23.1" },
|
||||
{ name = "httpx-sse", marker = "extra == 'a2a'", specifier = "~=0.4.0" },
|
||||
{ name = "ibm-watsonx-ai", marker = "extra == 'watson'", specifier = "~=1.3.39" },
|
||||
|
||||
Reference in New Issue
Block a user