mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
chore: improve typing
This commit is contained in:
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, ClassVar, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
@@ -20,6 +20,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
try:
|
||||
from anthropic import Anthropic
|
||||
from anthropic.types import Message
|
||||
@@ -48,6 +53,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
client_params: Additional parameters for the Anthropic client
|
||||
interceptor: HTTP interceptor for modifying requests/responses at transport level
|
||||
"""
|
||||
|
||||
base_url: str | None = Field(
|
||||
default=None, description="Custom base URL for Anthropic API"
|
||||
)
|
||||
@@ -121,8 +127,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Call Anthropic messages API.
|
||||
@@ -311,8 +317,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming message completion."""
|
||||
@@ -399,8 +405,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming message completion."""
|
||||
@@ -495,8 +501,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
tool_uses: list[ToolUseBlock],
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
) -> str:
|
||||
"""Handle the complete tool use conversation flow.
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
||||
@@ -18,6 +18,8 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
@@ -66,8 +68,6 @@ class AzureCompletion(BaseLLM):
|
||||
interceptor: HTTP interceptor (not yet supported for Azure)
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
endpoint: str | None = Field(
|
||||
default=None,
|
||||
description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)",
|
||||
@@ -91,7 +91,9 @@ class AzureCompletion(BaseLLM):
|
||||
default=None, description="Maximum tokens in response"
|
||||
)
|
||||
stream: bool = Field(default=False, description="Enable streaming responses")
|
||||
client: Any = Field(default=None, exclude=True, description="Azure client instance")
|
||||
_client: ChatCompletionsClient = PrivateAttr(
|
||||
default_factory=ChatCompletionsClient, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
_is_openai_model: bool = PrivateAttr(default=False)
|
||||
_is_azure_openai_endpoint: bool = PrivateAttr(default=False)
|
||||
@@ -190,8 +192,8 @@ class AzureCompletion(BaseLLM):
|
||||
tools: list[dict[str, BaseTool]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Call Azure AI Inference chat completions API.
|
||||
@@ -382,8 +384,8 @@ class AzureCompletion(BaseLLM):
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming chat completion."""
|
||||
@@ -478,8 +480,8 @@ class AzureCompletion(BaseLLM):
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming chat completion."""
|
||||
|
||||
@@ -32,6 +32,9 @@ if TYPE_CHECKING:
|
||||
ToolTypeDef,
|
||||
)
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
try:
|
||||
from boto3.session import Session
|
||||
@@ -261,8 +264,8 @@ class BedrockCompletion(BaseLLM):
|
||||
tools: list[dict[Any, Any]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Call AWS Bedrock Converse API."""
|
||||
@@ -347,8 +350,8 @@ class BedrockCompletion(BaseLLM):
|
||||
messages: list[dict[str, Any]],
|
||||
body: BedrockConverseRequestBody,
|
||||
available_functions: Mapping[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
) -> str:
|
||||
"""Handle non-streaming converse API call following AWS best practices."""
|
||||
try:
|
||||
@@ -528,8 +531,8 @@ class BedrockCompletion(BaseLLM):
|
||||
messages: list[dict[str, Any]],
|
||||
body: BedrockConverseRequestBody,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming converse API call with comprehensive event handling."""
|
||||
full_response = ""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, ClassVar, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
@@ -16,6 +16,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
try:
|
||||
from google import genai # type: ignore[import-untyped]
|
||||
from google.genai import types # type: ignore[import-untyped]
|
||||
@@ -196,8 +201,8 @@ class GeminiCompletion(BaseLLM):
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Call Google Gemini generate content API.
|
||||
@@ -383,8 +388,8 @@ class GeminiCompletion(BaseLLM):
|
||||
system_instruction: str | None,
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming content generation."""
|
||||
@@ -449,8 +454,8 @@ class GeminiCompletion(BaseLLM):
|
||||
contents: list[types.Content],
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming content generation."""
|
||||
|
||||
@@ -11,7 +11,7 @@ from openai import APIConnectionError, NotFoundError, OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
@@ -74,9 +74,7 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
reasoning_effort: str | None = Field(None, description="Reasoning effort level")
|
||||
|
||||
client: OpenAI = Field(
|
||||
default_factory=OpenAI, exclude=True, description="OpenAI client instance"
|
||||
)
|
||||
_client: OpenAI = PrivateAttr(default_factory=OpenAI)
|
||||
is_o1_model: bool = Field(False, description="Whether this is an O1 model")
|
||||
is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model")
|
||||
|
||||
@@ -92,7 +90,7 @@ class OpenAICompletion(BaseLLM):
|
||||
http_client = httpx.Client(transport=transport)
|
||||
client_config["http_client"] = http_client
|
||||
|
||||
self.client = OpenAI(**client_config)
|
||||
self._client = OpenAI(**client_config)
|
||||
|
||||
self.is_o1_model = "o1" in self.model.lower()
|
||||
self.is_gpt4_model = "gpt-4" in self.model.lower()
|
||||
@@ -279,14 +277,14 @@ class OpenAICompletion(BaseLLM):
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming chat completion."""
|
||||
try:
|
||||
if response_model:
|
||||
parsed_response = self.client.beta.chat.completions.parse(
|
||||
parsed_response = self._client.beta.chat.completions.parse(
|
||||
**params,
|
||||
response_format=response_model,
|
||||
)
|
||||
@@ -310,7 +308,7 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return structured_json
|
||||
|
||||
response: ChatCompletion = self.client.chat.completions.create(**params)
|
||||
response: ChatCompletion = self._client.chat.completions.create(**params)
|
||||
|
||||
usage = self._extract_openai_token_usage(response)
|
||||
|
||||
@@ -402,8 +400,8 @@ class OpenAICompletion(BaseLLM):
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming chat completion."""
|
||||
@@ -412,7 +410,7 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
if response_model:
|
||||
completion_stream: Iterator[ChatCompletionChunk] = (
|
||||
self.client.chat.completions.create(**params)
|
||||
self._client.chat.completions.create(**params)
|
||||
)
|
||||
|
||||
accumulated_content = ""
|
||||
@@ -455,7 +453,7 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return accumulated_content
|
||||
|
||||
stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
|
||||
stream: Iterator[ChatCompletionChunk] = self._client.chat.completions.create(
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user