chore: improve typing

Author: Greyson LaLonde
Date: 2025-11-11 17:37:08 -05:00
parent 6fb13ee3e0
commit 0803318002
5 changed files with 61 additions and 47 deletions

File 1 of 5: AnthropicCompletion

@@ -3,9 +3,9 @@ from __future__ import annotations
 import json
 import logging
 import os
-from typing import Any, ClassVar, cast
+from typing import TYPE_CHECKING, Any, cast
-from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
+from pydantic import BaseModel, Field, PrivateAttr, model_validator
 from typing_extensions import Self
 from crewai.events.types.llm_events import LLMCallType
@@ -20,6 +20,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
 from crewai.utilities.types import LLMMessage
+if TYPE_CHECKING:
+    from crewai.agent.core import Agent
+    from crewai.task import Task
 try:
     from anthropic import Anthropic
     from anthropic.types import Message
@@ -48,6 +53,7 @@ class AnthropicCompletion(BaseLLM):
         client_params: Additional parameters for the Anthropic client
         interceptor: HTTP interceptor for modifying requests/responses at transport level
     """
+
     base_url: str | None = Field(
         default=None, description="Custom base URL for Anthropic API"
     )
@@ -121,8 +127,8 @@ class AnthropicCompletion(BaseLLM):
         tools: list[dict[str, Any]] | None = None,
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call Anthropic messages API.
@@ -311,8 +317,8 @@ class AnthropicCompletion(BaseLLM):
         self,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Handle non-streaming message completion."""
@@ -399,8 +405,8 @@ class AnthropicCompletion(BaseLLM):
         self,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str:
         """Handle streaming message completion."""
@@ -495,8 +501,8 @@ class AnthropicCompletion(BaseLLM):
         tool_uses: list[ToolUseBlock],
         params: dict[str, Any],
         available_functions: dict[str, Any],
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
     ) -> str:
         """Handle the complete tool use conversation flow.

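Across all five providers the change is the same three-part move: import Task and Agent only under typing.TYPE_CHECKING, annotate from_task/from_agent with the real types instead of Any, and rely on postponed annotation evaluation so the runtime never executes imports that would be circular between the providers and crewai.task/crewai.agent.core. A minimal self-contained sketch of the idiom, with a hypothetical my_pkg.task module standing in for the real layout:

    from __future__ import annotations  # annotations stay strings at runtime

    from typing import TYPE_CHECKING, Any

    if TYPE_CHECKING:
        # Only type checkers execute this import, so the
        # task -> provider -> task circular import never happens.
        from my_pkg.task import Task


    def call(messages: list[dict[str, Any]], from_task: Task | None = None) -> str:
        # "Task | None" is never evaluated at runtime, so the missing
        # runtime import is harmless.
        return f"handled {len(messages)} messages (task={from_task!r})"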
File 2 of 5: AzureCompletion

@@ -3,9 +3,9 @@ from __future__ import annotations
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, Any
-from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
+from pydantic import BaseModel, Field, PrivateAttr, model_validator
 from typing_extensions import Self
 from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
@@ -18,6 +18,8 @@ from crewai.utilities.types import LLMMessage
 if TYPE_CHECKING:
+    from crewai.agent.core import Agent
+    from crewai.task import Task
     from crewai.tools.base_tool import BaseTool
@@ -66,8 +68,6 @@ class AzureCompletion(BaseLLM):
         interceptor: HTTP interceptor (not yet supported for Azure)
     """
-    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
     endpoint: str | None = Field(
         default=None,
         description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)",
@@ -91,7 +91,9 @@ class AzureCompletion(BaseLLM):
         default=None, description="Maximum tokens in response"
     )
     stream: bool = Field(default=False, description="Enable streaming responses")
-    client: Any = Field(default=None, exclude=True, description="Azure client instance")
+    _client: ChatCompletionsClient = PrivateAttr(
+        default_factory=ChatCompletionsClient,  # type: ignore[arg-type]
+    )
     _is_openai_model: bool = PrivateAttr(default=False)
     _is_azure_openai_endpoint: bool = PrivateAttr(default=False)
@@ -190,8 +192,8 @@ class AzureCompletion(BaseLLM):
         tools: list[dict[str, BaseTool]] | None = None,
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call Azure AI Inference chat completions API.
@@ -382,8 +384,8 @@ class AzureCompletion(BaseLLM):
         self,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Handle non-streaming chat completion."""
@@ -478,8 +480,8 @@ class AzureCompletion(BaseLLM):
         self,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str:
         """Handle streaming chat completion."""

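The Azure file also swaps a public client: Any = Field(...) for a typed PrivateAttr, which is why model_config = ConfigDict(arbitrary_types_allowed=True) could be deleted: pydantic neither validates nor serializes private attributes, so an arbitrary SDK type no longer needs that escape hatch. A minimal sketch, with an invented FakeClient standing in for ChatCompletionsClient:

    from pydantic import BaseModel, PrivateAttr


    class FakeClient:
        """Stand-in for a non-pydantic SDK client."""


    class Provider(BaseModel):
        endpoint: str

        # Private attrs skip validation and never appear in model_dump(),
        # so no ConfigDict(arbitrary_types_allowed=True) is required.
        _client: FakeClient = PrivateAttr(default_factory=FakeClient)


    p = Provider(endpoint="https://example.invalid")
    assert "_client" not in p.model_dump()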
File 3 of 5: BedrockCompletion

@@ -32,6 +32,9 @@ if TYPE_CHECKING:
         ToolTypeDef,
     )
+    from crewai.agent.core import Agent
+    from crewai.task import Task
 try:
     from boto3.session import Session
@@ -261,8 +264,8 @@ class BedrockCompletion(BaseLLM):
         tools: list[dict[Any, Any]] | None = None,
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call AWS Bedrock Converse API."""
@@ -347,8 +350,8 @@ class BedrockCompletion(BaseLLM):
         messages: list[dict[str, Any]],
         body: BedrockConverseRequestBody,
         available_functions: Mapping[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
     ) -> str:
         """Handle non-streaming converse API call following AWS best practices."""
         try:
@@ -528,8 +531,8 @@ class BedrockCompletion(BaseLLM):
         messages: list[dict[str, Any]],
         body: BedrockConverseRequestBody,
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
     ) -> str:
         """Handle streaming converse API call with comprehensive event handling."""
         full_response = ""

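In the Bedrock provider the new type-only imports sit alongside the existing optional-dependency guard, and the two mechanisms are complementary: TYPE_CHECKING imports cost nothing at runtime, while runtime-optional SDKs stay behind try/except ImportError. Roughly, with the fallback behavior assumed rather than taken from this diff:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from crewai.task import Task  # resolved only by type checkers

    try:
        from boto3.session import Session  # heavy, optional at runtime
    except ImportError:
        Session = None  # assumed fallback; the real module may raise instead


    def make_session(from_task: Task | None = None) -> Session | None:
        # Returns a boto3 Session when the SDK is installed, else None.
        return Session() if Session is not None else None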
File 4 of 5: GeminiCompletion

@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Any, ClassVar, cast
+from typing import TYPE_CHECKING, Any, ClassVar, cast
 from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
 from typing_extensions import Self
@@ -16,6 +16,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
 from crewai.utilities.types import LLMMessage
+if TYPE_CHECKING:
+    from crewai.agent.core import Agent
+    from crewai.task import Task
 try:
     from google import genai  # type: ignore[import-untyped]
     from google.genai import types  # type: ignore[import-untyped]
@@ -196,8 +201,8 @@ class GeminiCompletion(BaseLLM):
         tools: list[dict[str, Any]] | None = None,
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call Google Gemini generate content API.
@@ -383,8 +388,8 @@ class GeminiCompletion(BaseLLM):
         system_instruction: str | None,
         config: types.GenerateContentConfig,
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Handle non-streaming content generation."""
@@ -449,8 +454,8 @@ class GeminiCompletion(BaseLLM):
         contents: list[types.Content],
         config: types.GenerateContentConfig,
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str:
         """Handle streaming content generation."""

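One caveat these hunks surface: names imported only under TYPE_CHECKING must never be evaluated at runtime, so signatures like from_task: Task | None = None are only safe when annotation evaluation is postponed via from __future__ import annotations (visible in the Anthropic and Azure hunk headers above, but not in this file's @@ -1,6 +1,6 @@ hunk). A quick demonstration of why the future import matters:

    # Without postponed evaluation, "from_task: Task | None" would be
    # evaluated at def time and raise NameError, because Task is only
    # imported for type checkers. With it, annotations stay strings:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from crewai.task import Task


    def call(from_task: Task | None = None) -> None: ...


    print(call.__annotations__["from_task"])  # prints the string "Task | None"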
File 5 of 5: OpenAICompletion

@@ -11,7 +11,7 @@ from openai import APIConnectionError, NotFoundError, OpenAI
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, PrivateAttr, model_validator
 from typing_extensions import Self
 from crewai.events.types.llm_events import LLMCallType
@@ -74,9 +74,7 @@ class OpenAICompletion(BaseLLM):
     )
     reasoning_effort: str | None = Field(None, description="Reasoning effort level")
-    client: OpenAI = Field(
-        default_factory=OpenAI, exclude=True, description="OpenAI client instance"
-    )
+    _client: OpenAI = PrivateAttr(default_factory=OpenAI)
     is_o1_model: bool = Field(False, description="Whether this is an O1 model")
     is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model")
@@ -92,7 +90,7 @@ class OpenAICompletion(BaseLLM):
             http_client = httpx.Client(transport=transport)
             client_config["http_client"] = http_client
-        self.client = OpenAI(**client_config)
+        self._client = OpenAI(**client_config)
         self.is_o1_model = "o1" in self.model.lower()
         self.is_gpt4_model = "gpt-4" in self.model.lower()
@@ -279,14 +277,14 @@ class OpenAICompletion(BaseLLM):
         self,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Handle non-streaming chat completion."""
         try:
             if response_model:
-                parsed_response = self.client.beta.chat.completions.parse(
+                parsed_response = self._client.beta.chat.completions.parse(
                     **params,
                     response_format=response_model,
                 )
@@ -310,7 +308,7 @@ class OpenAICompletion(BaseLLM):
                 )
                 return structured_json
-            response: ChatCompletion = self.client.chat.completions.create(**params)
+            response: ChatCompletion = self._client.chat.completions.create(**params)
             usage = self._extract_openai_token_usage(response)
@@ -402,8 +400,8 @@ class OpenAICompletion(BaseLLM):
         self,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str:
         """Handle streaming chat completion."""
@@ -412,7 +410,7 @@ class OpenAICompletion(BaseLLM):
             if response_model:
                 completion_stream: Iterator[ChatCompletionChunk] = (
-                    self.client.chat.completions.create(**params)
+                    self._client.chat.completions.create(**params)
                 )
                 accumulated_content = ""
@@ -455,7 +453,7 @@ class OpenAICompletion(BaseLLM):
                 )
                 return accumulated_content
-            stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
+            stream: Iterator[ChatCompletionChunk] = self._client.chat.completions.create(
                 **params
             )
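
The OpenAI change is the same PrivateAttr move plus rewiring every call site from self.client to self._client, with the model validator rebuilding the client once the interceptor and config options are known. A hypothetical condensed version of that pattern (class name and validator body are illustrative; constructing OpenAI() requires an API key via argument or OPENAI_API_KEY):

    from __future__ import annotations

    from typing import Any

    from openai import OpenAI
    from pydantic import BaseModel, Field, PrivateAttr, model_validator
    from typing_extensions import Self


    class OpenAISettings(BaseModel):
        api_key: str | None = Field(default=None, description="API key")
        base_url: str | None = Field(default=None, description="Custom base URL")

        # Excluded from validation, model_dump(), and the JSON schema.
        _client: OpenAI = PrivateAttr()

        @model_validator(mode="after")
        def _build_client(self) -> Self:
            client_config: dict[str, Any] = {}
            if self.api_key:
                client_config["api_key"] = self.api_key
            if self.base_url:
                client_config["base_url"] = self.base_url
            self._client = OpenAI(**client_config)
            return self

Unlike the diff's PrivateAttr(default_factory=OpenAI), this sketch defers construction to the validator rather than eagerly building a client purely from environment variables.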