Compare commits

..

2 Commits

Author SHA1 Message Date
Rip&Tear
afa4d01c22 Merge branch 'main' into codex/fix-codeql-alert-28 2026-02-11 13:43:46 +08:00
theCyberTech
78ad77e619 fix: harden Azure endpoint validation for CodeQL alert 28 2026-02-11 13:28:03 +08:00
6 changed files with 82 additions and 26 deletions

View File

@@ -14,18 +14,13 @@ paths-ignore:
- "lib/crewai/src/crewai/experimental/a2a/**"
paths:
# Include GitHub Actions workflows/composite actions for CodeQL actions analysis
- ".github/workflows/**"
- ".github/actions/**"
# Include all Python source code from workspace packages
- "lib/crewai/src/**"
- "lib/crewai-tools/src/**"
- "lib/crewai-files/src/**"
- "lib/devtools/src/**"
# Include tests (but exclude cassettes via paths-ignore)
- "lib/crewai/tests/**"
- "lib/crewai-tools/tests/**"
- "lib/crewai-files/tests/**"
- "lib/devtools/tests/**"
# Configure specific queries or packs if needed

View File

@@ -69,7 +69,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v4
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
@@ -98,6 +98,6 @@ jobs:
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"

View File

@@ -33,11 +33,8 @@ def test_brave_tool_search(mock_get, brave_tool):
mock_get.return_value.json.return_value = mock_response
result = brave_tool.run(query="test")
data = json.loads(result)
assert isinstance(data, list)
assert len(data) >= 1
assert data[0]["title"] == "Test Title"
assert data[0]["url"] == "http://test.com"
assert "Test Title" in result
assert "http://test.com" in result
@patch("requests.get")

View File

@@ -1,10 +1,9 @@
import re
import sys
from urllib.parse import urlparse
from unittest.mock import MagicMock, patch
import pytest
# Create mock classes that will be used by our fixture
class MockStagehandModule:
def __init__(self):
@@ -172,14 +171,8 @@ def test_navigate_command(mock_run, stagehand_tool):
)
# Assertions
assert "Successfully navigated to " in result
assert "https://example.com" in result
# Extract URL from result string and check its host
# Example result: "Successfully navigated to https://example.com"
url_match = re.search(r"https?://[^\s]+", result)
assert url_match is not None
parsed = urlparse(url_match.group(0))
assert parsed.hostname == "example.com"
@patch(
"crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True

View File

@@ -4,6 +4,7 @@ import json
import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict
from urllib.parse import urlparse
from pydantic import BaseModel
from typing_extensions import Self
@@ -175,11 +176,51 @@ class AzureCompletion(BaseLLM):
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
)
self.is_azure_openai_endpoint = (
"openai.azure.com" in self.endpoint
and "/openai/deployments/" in self.endpoint
self.is_azure_openai_endpoint = self._is_azure_openai_deployment_endpoint(
self.endpoint
)
@staticmethod
def _parse_endpoint_url(endpoint: str):
    """Parse an endpoint string into URL components without raising.

    Args:
        endpoint: Endpoint value, with or without a URL scheme
            (``https://host/path`` or ``host/path``).

    Returns:
        The ``urllib.parse`` result for the endpoint. When the value has no
        recognisable host it is re-parsed with an ``https://`` prefix. When
        the value cannot be parsed at all, an empty parse result is returned
        (``hostname`` is ``None`` and ``path`` is ``""``).
    """
    try:
        parsed_endpoint = urlparse(endpoint)
        if parsed_endpoint.hostname:
            return parsed_endpoint
        # Support endpoint values without a URL scheme.
        return urlparse(f"https://{endpoint}")
    except ValueError:
        # urlparse rejects malformed authorities (for example an unbalanced
        # "[" in the host). Endpoint detection should classify such a value
        # as "no host" instead of propagating the parsing error.
        return urlparse("")
@staticmethod
def _is_azure_openai_hostname(endpoint: str) -> bool:
    """Return True when the endpoint's host ends in ``openai.azure.com``.

    The comparison is made on the parsed hostname, label by label and
    case-insensitively, rather than on the raw endpoint string.

    Args:
        endpoint: Endpoint value to inspect.

    Returns:
        True if the last three hostname labels are ``openai``, ``azure``,
        ``com``; False otherwise, including when no hostname is present.
    """
    host = AzureCompletion._parse_endpoint_url(endpoint).hostname
    if not host:
        return False
    # Empty labels (such as the one left by a trailing dot) are discarded
    # before the suffix comparison.
    parts = [part for part in host.lower().split(".") if part]
    # A three-element slice can only equal the expected list when at least
    # three labels exist, so no separate length check is needed.
    return parts[-3:] == ["openai", "azure", "com"]
@staticmethod
def _get_endpoint_path_segments(endpoint: str) -> list[str]:
    """Return the non-empty ``/``-separated components of the endpoint path.

    Args:
        endpoint: Endpoint value to inspect.

    Returns:
        Path components in order, with empty components (from leading,
        trailing or doubled slashes) removed.
    """
    url_path = AzureCompletion._parse_endpoint_url(endpoint).path
    # filter(None, ...) drops the empty strings produced by split("/").
    return list(filter(None, url_path.split("/")))
@staticmethod
def _is_azure_openai_deployment_endpoint(endpoint: str) -> bool:
    """Return True for an Azure OpenAI URL that addresses a deployment.

    Args:
        endpoint: Endpoint value to inspect.

    Returns:
        True when the host passes ``_is_azure_openai_hostname`` and the path
        starts with ``openai/deployments/`` followed by at least one more
        segment; False otherwise.
    """
    if AzureCompletion._is_azure_openai_hostname(endpoint):
        segments = AzureCompletion._get_endpoint_path_segments(endpoint)
        # The third segment is the part after /openai/deployments/.
        has_deployment_segment = len(segments) >= 3
        return has_deployment_segment and segments[:2] == ["openai", "deployments"]
    return False
@staticmethod
def _is_azure_openai_deployments_collection(endpoint: str) -> bool:
    """Return True for an Azure OpenAI URL whose path is exactly ``openai/deployments``.

    Args:
        endpoint: Endpoint value to inspect.

    Returns:
        True when the host passes ``_is_azure_openai_hostname`` and the path
        consists of the two segments ``openai`` and ``deployments`` only;
        False otherwise.
    """
    on_azure_host = AzureCompletion._is_azure_openai_hostname(endpoint)
    # The path is only inspected once the host check has passed.
    return on_azure_host and (
        AzureCompletion._get_endpoint_path_segments(endpoint)
        == ["openai", "deployments"]
    )
@staticmethod
def _validate_and_fix_endpoint(endpoint: str, model: str) -> str:
"""Validate and fix Azure endpoint URL format.
@@ -194,10 +235,12 @@ class AzureCompletion(BaseLLM):
Returns:
Validated and potentially corrected endpoint URL
"""
if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
if AzureCompletion._is_azure_openai_hostname(
endpoint
) and not AzureCompletion._is_azure_openai_deployment_endpoint(endpoint):
endpoint = endpoint.rstrip("/")
if not endpoint.endswith("/openai/deployments"):
if not AzureCompletion._is_azure_openai_deployments_collection(endpoint):
deployment_name = model.replace("azure/", "")
endpoint = f"{endpoint}/openai/deployments/{deployment_name}"
logging.info(f"Constructed Azure OpenAI endpoint URL: {endpoint}")

View File

@@ -958,6 +958,34 @@ def test_azure_endpoint_detection_flags():
assert llm_other.is_azure_openai_endpoint == False
def test_azure_endpoint_detection_ignores_spoofed_urls():
    """
    Test that endpoint detection does not trust spoofed host/path substrings
    """
    spoofed_endpoints = (
        # Azure OpenAI URL carried inside the query string of another host.
        (
            "https://evil.example.com/?redirect="
            "https://test.openai.azure.com/openai/deployments/gpt-4"
        ),
        # Azure OpenAI domain used as a prefix of a different host.
        "https://test.openai.azure.com.evil/openai/deployments/gpt-4",
    )

    for spoofed_endpoint in spoofed_endpoints:
        env = {"AZURE_API_KEY": "test-key", "AZURE_ENDPOINT": spoofed_endpoint}
        with patch.dict(os.environ, env):
            llm = LLM(model="azure/gpt-4")
            assert llm.is_azure_openai_endpoint == False
            params = llm._prepare_completion_params(
                messages=[{"role": "user", "content": "test"}]
            )
            assert "model" in params
def test_azure_improved_error_messages():
"""
Test that improved error messages are provided for common HTTP errors