mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
chore: improve typing
This commit is contained in:
@@ -3,9 +3,9 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, ClassVar, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from crewai.events.types.llm_events import LLMCallType
|
from crewai.events.types.llm_events import LLMCallType
|
||||||
@@ -20,6 +20,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
|||||||
from crewai.utilities.types import LLMMessage
|
from crewai.utilities.types import LLMMessage
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.agent.core import Agent
|
||||||
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from anthropic import Anthropic
|
from anthropic import Anthropic
|
||||||
from anthropic.types import Message
|
from anthropic.types import Message
|
||||||
@@ -48,6 +53,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
client_params: Additional parameters for the Anthropic client
|
client_params: Additional parameters for the Anthropic client
|
||||||
interceptor: HTTP interceptor for modifying requests/responses at transport level
|
interceptor: HTTP interceptor for modifying requests/responses at transport level
|
||||||
"""
|
"""
|
||||||
|
|
||||||
base_url: str | None = Field(
|
base_url: str | None = Field(
|
||||||
default=None, description="Custom base URL for Anthropic API"
|
default=None, description="Custom base URL for Anthropic API"
|
||||||
)
|
)
|
||||||
@@ -121,8 +127,8 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
callbacks: list[Any] | None = None,
|
callbacks: list[Any] | None = None,
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Call Anthropic messages API.
|
"""Call Anthropic messages API.
|
||||||
@@ -311,8 +317,8 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
self,
|
self,
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Handle non-streaming message completion."""
|
"""Handle non-streaming message completion."""
|
||||||
@@ -399,8 +405,8 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
self,
|
self,
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle streaming message completion."""
|
"""Handle streaming message completion."""
|
||||||
@@ -495,8 +501,8 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
tool_uses: list[ToolUseBlock],
|
tool_uses: list[ToolUseBlock],
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
available_functions: dict[str, Any],
|
available_functions: dict[str, Any],
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle the complete tool use conversation flow.
|
"""Handle the complete tool use conversation flow.
|
||||||
|
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
||||||
@@ -18,6 +18,8 @@ from crewai.utilities.types import LLMMessage
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from crewai.agent.core import Agent
|
||||||
|
from crewai.task import Task
|
||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
|
|
||||||
|
|
||||||
@@ -66,8 +68,6 @@ class AzureCompletion(BaseLLM):
|
|||||||
interceptor: HTTP interceptor (not yet supported for Azure)
|
interceptor: HTTP interceptor (not yet supported for Azure)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
|
|
||||||
endpoint: str | None = Field(
|
endpoint: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)",
|
description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)",
|
||||||
@@ -91,7 +91,9 @@ class AzureCompletion(BaseLLM):
|
|||||||
default=None, description="Maximum tokens in response"
|
default=None, description="Maximum tokens in response"
|
||||||
)
|
)
|
||||||
stream: bool = Field(default=False, description="Enable streaming responses")
|
stream: bool = Field(default=False, description="Enable streaming responses")
|
||||||
client: Any = Field(default=None, exclude=True, description="Azure client instance")
|
_client: ChatCompletionsClient = PrivateAttr(
|
||||||
|
default_factory=ChatCompletionsClient, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
_is_openai_model: bool = PrivateAttr(default=False)
|
_is_openai_model: bool = PrivateAttr(default=False)
|
||||||
_is_azure_openai_endpoint: bool = PrivateAttr(default=False)
|
_is_azure_openai_endpoint: bool = PrivateAttr(default=False)
|
||||||
@@ -190,8 +192,8 @@ class AzureCompletion(BaseLLM):
|
|||||||
tools: list[dict[str, BaseTool]] | None = None,
|
tools: list[dict[str, BaseTool]] | None = None,
|
||||||
callbacks: list[Any] | None = None,
|
callbacks: list[Any] | None = None,
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Call Azure AI Inference chat completions API.
|
"""Call Azure AI Inference chat completions API.
|
||||||
@@ -382,8 +384,8 @@ class AzureCompletion(BaseLLM):
|
|||||||
self,
|
self,
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Handle non-streaming chat completion."""
|
"""Handle non-streaming chat completion."""
|
||||||
@@ -478,8 +480,8 @@ class AzureCompletion(BaseLLM):
|
|||||||
self,
|
self,
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle streaming chat completion."""
|
"""Handle streaming chat completion."""
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ if TYPE_CHECKING:
|
|||||||
ToolTypeDef,
|
ToolTypeDef,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from crewai.agent.core import Agent
|
||||||
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from boto3.session import Session
|
from boto3.session import Session
|
||||||
@@ -261,8 +264,8 @@ class BedrockCompletion(BaseLLM):
|
|||||||
tools: list[dict[Any, Any]] | None = None,
|
tools: list[dict[Any, Any]] | None = None,
|
||||||
callbacks: list[Any] | None = None,
|
callbacks: list[Any] | None = None,
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Call AWS Bedrock Converse API."""
|
"""Call AWS Bedrock Converse API."""
|
||||||
@@ -347,8 +350,8 @@ class BedrockCompletion(BaseLLM):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
body: BedrockConverseRequestBody,
|
body: BedrockConverseRequestBody,
|
||||||
available_functions: Mapping[str, Any] | None = None,
|
available_functions: Mapping[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle non-streaming converse API call following AWS best practices."""
|
"""Handle non-streaming converse API call following AWS best practices."""
|
||||||
try:
|
try:
|
||||||
@@ -528,8 +531,8 @@ class BedrockCompletion(BaseLLM):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
body: BedrockConverseRequestBody,
|
body: BedrockConverseRequestBody,
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle streaming converse API call with comprehensive event handling."""
|
"""Handle streaming converse API call with comprehensive event handling."""
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, ClassVar, cast
|
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
@@ -16,6 +16,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
|||||||
from crewai.utilities.types import LLMMessage
|
from crewai.utilities.types import LLMMessage
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.agent.core import Agent
|
||||||
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from google import genai # type: ignore[import-untyped]
|
from google import genai # type: ignore[import-untyped]
|
||||||
from google.genai import types # type: ignore[import-untyped]
|
from google.genai import types # type: ignore[import-untyped]
|
||||||
@@ -196,8 +201,8 @@ class GeminiCompletion(BaseLLM):
|
|||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
callbacks: list[Any] | None = None,
|
callbacks: list[Any] | None = None,
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Call Google Gemini generate content API.
|
"""Call Google Gemini generate content API.
|
||||||
@@ -383,8 +388,8 @@ class GeminiCompletion(BaseLLM):
|
|||||||
system_instruction: str | None,
|
system_instruction: str | None,
|
||||||
config: types.GenerateContentConfig,
|
config: types.GenerateContentConfig,
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Handle non-streaming content generation."""
|
"""Handle non-streaming content generation."""
|
||||||
@@ -449,8 +454,8 @@ class GeminiCompletion(BaseLLM):
|
|||||||
contents: list[types.Content],
|
contents: list[types.Content],
|
||||||
config: types.GenerateContentConfig,
|
config: types.GenerateContentConfig,
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle streaming content generation."""
|
"""Handle streaming content generation."""
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from openai import APIConnectionError, NotFoundError, OpenAI
|
|||||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||||
from openai.types.chat.chat_completion import Choice
|
from openai.types.chat.chat_completion import Choice
|
||||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from crewai.events.types.llm_events import LLMCallType
|
from crewai.events.types.llm_events import LLMCallType
|
||||||
@@ -74,9 +74,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
)
|
)
|
||||||
reasoning_effort: str | None = Field(None, description="Reasoning effort level")
|
reasoning_effort: str | None = Field(None, description="Reasoning effort level")
|
||||||
|
|
||||||
client: OpenAI = Field(
|
_client: OpenAI = PrivateAttr(default_factory=OpenAI)
|
||||||
default_factory=OpenAI, exclude=True, description="OpenAI client instance"
|
|
||||||
)
|
|
||||||
is_o1_model: bool = Field(False, description="Whether this is an O1 model")
|
is_o1_model: bool = Field(False, description="Whether this is an O1 model")
|
||||||
is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model")
|
is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model")
|
||||||
|
|
||||||
@@ -92,7 +90,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
http_client = httpx.Client(transport=transport)
|
http_client = httpx.Client(transport=transport)
|
||||||
client_config["http_client"] = http_client
|
client_config["http_client"] = http_client
|
||||||
|
|
||||||
self.client = OpenAI(**client_config)
|
self._client = OpenAI(**client_config)
|
||||||
|
|
||||||
self.is_o1_model = "o1" in self.model.lower()
|
self.is_o1_model = "o1" in self.model.lower()
|
||||||
self.is_gpt4_model = "gpt-4" in self.model.lower()
|
self.is_gpt4_model = "gpt-4" in self.model.lower()
|
||||||
@@ -279,14 +277,14 @@ class OpenAICompletion(BaseLLM):
|
|||||||
self,
|
self,
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Handle non-streaming chat completion."""
|
"""Handle non-streaming chat completion."""
|
||||||
try:
|
try:
|
||||||
if response_model:
|
if response_model:
|
||||||
parsed_response = self.client.beta.chat.completions.parse(
|
parsed_response = self._client.beta.chat.completions.parse(
|
||||||
**params,
|
**params,
|
||||||
response_format=response_model,
|
response_format=response_model,
|
||||||
)
|
)
|
||||||
@@ -310,7 +308,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
)
|
)
|
||||||
return structured_json
|
return structured_json
|
||||||
|
|
||||||
response: ChatCompletion = self.client.chat.completions.create(**params)
|
response: ChatCompletion = self._client.chat.completions.create(**params)
|
||||||
|
|
||||||
usage = self._extract_openai_token_usage(response)
|
usage = self._extract_openai_token_usage(response)
|
||||||
|
|
||||||
@@ -402,8 +400,8 @@ class OpenAICompletion(BaseLLM):
|
|||||||
self,
|
self,
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle streaming chat completion."""
|
"""Handle streaming chat completion."""
|
||||||
@@ -412,7 +410,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
|
|
||||||
if response_model:
|
if response_model:
|
||||||
completion_stream: Iterator[ChatCompletionChunk] = (
|
completion_stream: Iterator[ChatCompletionChunk] = (
|
||||||
self.client.chat.completions.create(**params)
|
self._client.chat.completions.create(**params)
|
||||||
)
|
)
|
||||||
|
|
||||||
accumulated_content = ""
|
accumulated_content = ""
|
||||||
@@ -455,7 +453,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
)
|
)
|
||||||
return accumulated_content
|
return accumulated_content
|
||||||
|
|
||||||
stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
|
stream: Iterator[ChatCompletionChunk] = self._client.chat.completions.create(
|
||||||
**params
|
**params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user