From 08033180028b7c79888599dcad24d422e1958313 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 11 Nov 2025 17:37:08 -0500 Subject: [PATCH] chore: improve typing --- .../llm/providers/anthropic/completion.py | 26 ++++++++++++------- .../crewai/llm/providers/azure/completion.py | 24 +++++++++-------- .../llm/providers/bedrock/completion.py | 15 ++++++----- .../crewai/llm/providers/gemini/completion.py | 19 +++++++++----- .../crewai/llm/providers/openai/completion.py | 24 ++++++++--------- 5 files changed, 61 insertions(+), 47 deletions(-) diff --git a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py index 05366bbf4..8678075d3 100644 --- a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py @@ -3,9 +3,9 @@ from __future__ import annotations import json import logging import os -from typing import Any, ClassVar, cast +from typing import TYPE_CHECKING, Any, cast -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType @@ -20,6 +20,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( from crewai.utilities.types import LLMMessage +if TYPE_CHECKING: + from crewai.agent.core import Agent + from crewai.task import Task + + try: from anthropic import Anthropic from anthropic.types import Message @@ -48,6 +53,7 @@ class AnthropicCompletion(BaseLLM): client_params: Additional parameters for the Anthropic client interceptor: HTTP interceptor for modifying requests/responses at transport level """ + base_url: str | None = Field( default=None, description="Custom base URL for Anthropic API" ) @@ -121,8 +127,8 @@ class AnthropicCompletion(BaseLLM): tools: list[dict[str, Any]] | None = None, callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call Anthropic messages API. @@ -311,8 +317,8 @@ class AnthropicCompletion(BaseLLM): self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming message completion.""" @@ -399,8 +405,8 @@ class AnthropicCompletion(BaseLLM): self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str: """Handle streaming message completion.""" @@ -495,8 +501,8 @@ class AnthropicCompletion(BaseLLM): tool_uses: list[ToolUseBlock], params: dict[str, Any], available_functions: dict[str, Any], - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, ) -> str: """Handle the complete tool use conversation flow. diff --git a/lib/crewai/src/crewai/llm/providers/azure/completion.py b/lib/crewai/src/crewai/llm/providers/azure/completion.py index 3a4f68d08..8ad8eb783 100644 --- a/lib/crewai/src/crewai/llm/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llm/providers/azure/completion.py @@ -3,9 +3,9 @@ from __future__ import annotations import json import logging import os -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES @@ -18,6 +18,8 @@ from crewai.utilities.types import LLMMessage if TYPE_CHECKING: + from crewai.agent.core import Agent + from crewai.task import Task from crewai.tools.base_tool import BaseTool @@ -66,8 +68,6 @@ class AzureCompletion(BaseLLM): interceptor: HTTP interceptor (not yet supported for Azure) """ - model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) - endpoint: str | None = Field( default=None, description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)", @@ -91,7 +91,9 @@ class AzureCompletion(BaseLLM): default=None, description="Maximum tokens in response" ) stream: bool = Field(default=False, description="Enable streaming responses") - client: Any = Field(default=None, exclude=True, description="Azure client instance") + _client: ChatCompletionsClient = PrivateAttr( + default_factory=ChatCompletionsClient, # type: ignore[arg-type] + ) _is_openai_model: bool = PrivateAttr(default=False) _is_azure_openai_endpoint: bool = PrivateAttr(default=False) @@ -190,8 +192,8 @@ class AzureCompletion(BaseLLM): tools: list[dict[str, BaseTool]] | None = None, callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call Azure AI Inference chat completions API. @@ -382,8 +384,8 @@ class AzureCompletion(BaseLLM): self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming chat completion.""" @@ -478,8 +480,8 @@ class AzureCompletion(BaseLLM): self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str: """Handle streaming chat completion.""" diff --git a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py index 58495a151..935eb3432 100644 --- a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py @@ -32,6 +32,9 @@ if TYPE_CHECKING: ToolTypeDef, ) + from crewai.agent.core import Agent + from crewai.task import Task + try: from boto3.session import Session @@ -261,8 +264,8 @@ class BedrockCompletion(BaseLLM): tools: list[dict[Any, Any]] | None = None, callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call AWS Bedrock Converse API.""" @@ -347,8 +350,8 @@ class BedrockCompletion(BaseLLM): messages: list[dict[str, Any]], body: BedrockConverseRequestBody, available_functions: Mapping[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, ) -> str: """Handle non-streaming converse API call following AWS best practices.""" try: @@ -528,8 +531,8 @@ class BedrockCompletion(BaseLLM): messages: list[dict[str, Any]], body: BedrockConverseRequestBody, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, ) -> str: """Handle streaming converse API call with comprehensive event handling.""" full_response = "" diff --git a/lib/crewai/src/crewai/llm/providers/gemini/completion.py b/lib/crewai/src/crewai/llm/providers/gemini/completion.py index f2b191656..5adfa3863 100644 --- a/lib/crewai/src/crewai/llm/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llm/providers/gemini/completion.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, ClassVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from typing_extensions import Self @@ -16,6 +16,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( from crewai.utilities.types import LLMMessage +if TYPE_CHECKING: + from crewai.agent.core import Agent + from crewai.task import Task + + try: from google import genai # type: ignore[import-untyped] from google.genai import types # type: ignore[import-untyped] @@ -196,8 +201,8 @@ class GeminiCompletion(BaseLLM): tools: list[dict[str, Any]] | None = None, callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call Google Gemini generate content API. @@ -383,8 +388,8 @@ class GeminiCompletion(BaseLLM): system_instruction: str | None, config: types.GenerateContentConfig, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming content generation.""" @@ -449,8 +454,8 @@ class GeminiCompletion(BaseLLM): contents: list[types.Content], config: types.GenerateContentConfig, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str: """Handle streaming content generation.""" diff --git a/lib/crewai/src/crewai/llm/providers/openai/completion.py b/lib/crewai/src/crewai/llm/providers/openai/completion.py index b9fcc99c7..d557f7dc5 100644 --- a/lib/crewai/src/crewai/llm/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llm/providers/openai/completion.py @@ -11,7 +11,7 @@ from openai import APIConnectionError, NotFoundError, OpenAI from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import ChoiceDelta -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType @@ -74,9 +74,7 @@ class OpenAICompletion(BaseLLM): ) reasoning_effort: str | None = Field(None, description="Reasoning effort level") - client: OpenAI = Field( - default_factory=OpenAI, exclude=True, description="OpenAI client instance" - ) + _client: OpenAI = PrivateAttr(default_factory=OpenAI) is_o1_model: bool = Field(False, description="Whether this is an O1 model") is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model") @@ -92,7 +90,7 @@ class OpenAICompletion(BaseLLM): http_client = httpx.Client(transport=transport) client_config["http_client"] = http_client - self.client = OpenAI(**client_config) + self._client = OpenAI(**client_config) self.is_o1_model = "o1" in self.model.lower() self.is_gpt4_model = "gpt-4" in self.model.lower() @@ -279,14 +277,14 @@ class OpenAICompletion(BaseLLM): self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming chat completion.""" try: if response_model: - parsed_response = self.client.beta.chat.completions.parse( + parsed_response = self._client.beta.chat.completions.parse( **params, response_format=response_model, ) @@ -310,7 +308,7 @@ class OpenAICompletion(BaseLLM): ) return structured_json - response: ChatCompletion = self.client.chat.completions.create(**params) + response: ChatCompletion = self._client.chat.completions.create(**params) usage = self._extract_openai_token_usage(response) @@ -402,8 +400,8 @@ class OpenAICompletion(BaseLLM): self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str: """Handle streaming chat completion.""" @@ -412,7 +410,7 @@ class OpenAICompletion(BaseLLM): if response_model: completion_stream: Iterator[ChatCompletionChunk] = ( - self.client.chat.completions.create(**params) + self._client.chat.completions.create(**params) ) accumulated_content = "" @@ -455,7 +453,7 @@ class OpenAICompletion(BaseLLM): ) return accumulated_content - stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create( + stream: Iterator[ChatCompletionChunk] = self._client.chat.completions.create( **params )