diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py
index e7fd80844..abc0b54ae 100644
--- a/lib/crewai/src/crewai/llms/providers/azure/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py
@@ -4,6 +4,7 @@ import json
 import logging
 import os
 from typing import TYPE_CHECKING, Any, TypedDict
+from urllib.parse import ParseResult, urlparse
 
 from pydantic import BaseModel
 from typing_extensions import Self
@@ -175,11 +176,106 @@ class AzureCompletion(BaseLLM):
             prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
         )
 
-        self.is_azure_openai_endpoint = (
-            "openai.azure.com" in self.endpoint
-            and "/openai/deployments/" in self.endpoint
+        self.is_azure_openai_endpoint = self._is_azure_openai_deployment_endpoint(
+            self.endpoint
         )
 
+    @staticmethod
+    def _parse_endpoint_url(endpoint: str) -> ParseResult:
+        """Parse an endpoint, tolerating a missing scheme or a malformed URL.
+
+        Args:
+            endpoint: Configured endpoint value, with or without a URL scheme.
+
+        Returns:
+            The parsed URL. A value that ``urlparse`` rejects yields an empty
+            result (no hostname, empty path) instead of raising.
+        """
+        try:
+            parsed_endpoint = urlparse(endpoint)
+            if parsed_endpoint.hostname:
+                return parsed_endpoint
+
+            # Support endpoint values without a URL scheme.
+            return urlparse(f"https://{endpoint}")
+        except ValueError:
+            # urlparse raises on malformed netlocs such as unmatched "[" or "]".
+            # These helpers only classify the endpoint, and a malformed value
+            # is never an Azure OpenAI endpoint, so report it as an empty URL.
+            return urlparse("")
+
+    @staticmethod
+    def _is_azure_openai_hostname(endpoint: str) -> bool:
+        """Check whether the endpoint host is under ``openai.azure.com``.
+
+        The host is compared label by label, so look-alike hosts such as
+        ``test.openai.azure.com.evil`` and matches inside the path or query
+        string are rejected.
+
+        Args:
+            endpoint: Configured endpoint value, with or without a URL scheme.
+
+        Returns:
+            True if the host is ``openai.azure.com`` or one of its subdomains.
+        """
+        parsed_endpoint = AzureCompletion._parse_endpoint_url(endpoint)
+        hostname = parsed_endpoint.hostname or ""
+        # Dropping empty labels also accepts a fully-qualified trailing dot.
+        labels = [label for label in hostname.lower().split(".") if label]
+
+        return labels[-3:] == ["openai", "azure", "com"]
+
+    @staticmethod
+    def _get_endpoint_path_segments(endpoint: str) -> list[str]:
+        """Return the non-empty path segments of the endpoint URL.
+
+        Args:
+            endpoint: Configured endpoint value, with or without a URL scheme.
+
+        Returns:
+            Path components split on ``/`` with empty segments removed.
+        """
+        parsed_endpoint = AzureCompletion._parse_endpoint_url(endpoint)
+        return [segment for segment in parsed_endpoint.path.split("/") if segment]
+
+    @staticmethod
+    def _is_azure_openai_deployment_endpoint(endpoint: str) -> bool:
+        """Check whether the endpoint targets a specific Azure OpenAI deployment.
+
+        Args:
+            endpoint: Configured endpoint value, with or without a URL scheme.
+
+        Returns:
+            True if the host is an Azure OpenAI host and the path starts with
+            ``/openai/deployments/`` followed by a deployment name.
+        """
+        if not AzureCompletion._is_azure_openai_hostname(endpoint):
+            return False
+
+        path_segments = AzureCompletion._get_endpoint_path_segments(endpoint)
+        # A third segment (the deployment name) must follow the prefix.
+        return len(path_segments) >= 3 and path_segments[:2] == [
+            "openai",
+            "deployments",
+        ]
+
+    @staticmethod
+    def _is_azure_openai_deployments_collection(endpoint: str) -> bool:
+        """Check whether the endpoint stops at ``/openai/deployments``.
+
+        Args:
+            endpoint: Configured endpoint value, with or without a URL scheme.
+
+        Returns:
+            True if the host is an Azure OpenAI host and the path is exactly
+            ``/openai/deployments`` with no deployment name.
+        """
+        if not AzureCompletion._is_azure_openai_hostname(endpoint):
+            return False
+
+        path_segments = AzureCompletion._get_endpoint_path_segments(endpoint)
+        return path_segments == ["openai", "deployments"]
+
     @staticmethod
     def _validate_and_fix_endpoint(endpoint: str, model: str) -> str:
         """Validate and fix Azure endpoint URL format.
@@ -194,10 +290,12 @@ class AzureCompletion(BaseLLM):
         Returns:
             Validated and potentially corrected endpoint URL
         """
-        if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
+        if AzureCompletion._is_azure_openai_hostname(
+            endpoint
+        ) and not AzureCompletion._is_azure_openai_deployment_endpoint(endpoint):
             endpoint = endpoint.rstrip("/")
 
-            if not endpoint.endswith("/openai/deployments"):
+            if not AzureCompletion._is_azure_openai_deployments_collection(endpoint):
                 deployment_name = model.replace("azure/", "")
                 endpoint = f"{endpoint}/openai/deployments/{deployment_name}"
                 logging.info(f"Constructed Azure OpenAI endpoint URL: {endpoint}")
diff --git a/lib/crewai/tests/llms/azure/test_azure.py b/lib/crewai/tests/llms/azure/test_azure.py
index 17a01bb56..668914c20 100644
--- a/lib/crewai/tests/llms/azure/test_azure.py
+++ b/lib/crewai/tests/llms/azure/test_azure.py
@@ -958,6 +958,45 @@ def test_azure_endpoint_detection_flags():
         assert llm_other.is_azure_openai_endpoint == False
 
 
+def test_azure_endpoint_detection_ignores_spoofed_urls():
+    """
+    Test that endpoint detection does not trust spoofed host/path substrings
+    """
+    with patch.dict(os.environ, {
+        "AZURE_API_KEY": "test-key",
+        "AZURE_ENDPOINT": (
+            "https://evil.example.com/?redirect="
+            "https://test.openai.azure.com/openai/deployments/gpt-4"
+        ),
+    }):
+        llm_query_spoof = LLM(model="azure/gpt-4")
+        assert llm_query_spoof.is_azure_openai_endpoint == False
+        assert "model" in llm_query_spoof._prepare_completion_params(
+            messages=[{"role": "user", "content": "test"}]
+        )
+
+    with patch.dict(os.environ, {
+        "AZURE_API_KEY": "test-key",
+        "AZURE_ENDPOINT": "https://test.openai.azure.com.evil/openai/deployments/gpt-4",
+    }):
+        llm_host_spoof = LLM(model="azure/gpt-4")
+        assert llm_host_spoof.is_azure_openai_endpoint == False
+        assert "model" in llm_host_spoof._prepare_completion_params(
+            messages=[{"role": "user", "content": "test"}]
+        )
+
+
+def test_azure_endpoint_detection_handles_malformed_urls():
+    """
+    Test that endpoint detection does not raise on URLs that urlparse rejects
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion
+
+    malformed = "https://[test.openai.azure.com/openai/deployments/gpt-4"
+    assert AzureCompletion._is_azure_openai_hostname(malformed) == False
+    assert AzureCompletion._is_azure_openai_deployment_endpoint(malformed) == False
+
+
 def test_azure_improved_error_messages():
     """
     Test that improved error messages are provided for common HTTP errors