mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
chore: improve typing in task module
This commit is contained in:
@@ -51,6 +51,7 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
)
|
||||
from crewai.utilities.logger_utils import suppress_warnings
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
|
||||
|
||||
try:
|
||||
@@ -75,8 +76,13 @@ try:
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionDeltaToolCall,
|
||||
Choices,
|
||||
Delta as LiteLLMDelta,
|
||||
Function,
|
||||
Message,
|
||||
ModelResponse,
|
||||
ModelResponseBase,
|
||||
ModelResponseStream,
|
||||
StreamingChoices as LiteLLMStreamingChoices,
|
||||
)
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
@@ -85,6 +91,11 @@ except ImportError:
|
||||
LITELLM_AVAILABLE = False
|
||||
litellm = None # type: ignore[assignment]
|
||||
Choices = None # type: ignore[assignment, misc]
|
||||
LiteLLMDelta = None # type: ignore[assignment, misc]
|
||||
Message = None # type: ignore[assignment, misc]
|
||||
ModelResponseBase = None # type: ignore[assignment, misc]
|
||||
ModelResponseStream = None # type: ignore[assignment, misc]
|
||||
LiteLLMStreamingChoices = None # type: ignore[assignment, misc]
|
||||
get_supported_openai_params = None # type: ignore[assignment]
|
||||
ChatCompletionDeltaToolCall = None # type: ignore[assignment, misc]
|
||||
Function = None # type: ignore[assignment, misc]
|
||||
@@ -709,7 +720,7 @@ class LLM(BaseLLM):
|
||||
chunk_content = None
|
||||
response_id = None
|
||||
|
||||
if hasattr(chunk, "id"):
|
||||
if isinstance(chunk, ModelResponseBase):
|
||||
response_id = chunk.id
|
||||
|
||||
# Safely extract content from various chunk formats
|
||||
@@ -718,18 +729,16 @@ class LLM(BaseLLM):
|
||||
choices = None
|
||||
if isinstance(chunk, dict) and "choices" in chunk:
|
||||
choices = chunk["choices"]
|
||||
elif hasattr(chunk, "choices"):
|
||||
# Check if choices is not a type but an actual attribute with value
|
||||
if not isinstance(chunk.choices, type):
|
||||
choices = chunk.choices
|
||||
elif isinstance(chunk, ModelResponseStream):
|
||||
choices = chunk.choices
|
||||
|
||||
# Try to extract usage information if available
|
||||
# NOTE: usage is a pydantic extra field on ModelResponseBase,
|
||||
# so it must be accessed via model_extra.
|
||||
if isinstance(chunk, dict) and "usage" in chunk:
|
||||
usage_info = chunk["usage"]
|
||||
elif hasattr(chunk, "usage"):
|
||||
# Check if usage is not a type but an actual attribute with value
|
||||
if not isinstance(chunk.usage, type):
|
||||
usage_info = chunk.usage
|
||||
elif isinstance(chunk, ModelResponseBase) and chunk.model_extra:
|
||||
usage_info = chunk.model_extra.get("usage") or usage_info
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
@@ -738,7 +747,7 @@ class LLM(BaseLLM):
|
||||
delta = None
|
||||
if isinstance(choice, dict) and "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
elif hasattr(choice, "delta"):
|
||||
elif isinstance(choice, LiteLLMStreamingChoices):
|
||||
delta = choice.delta
|
||||
|
||||
# Extract content from delta
|
||||
@@ -748,7 +757,7 @@ class LLM(BaseLLM):
|
||||
if "content" in delta and delta["content"] is not None:
|
||||
chunk_content = delta["content"]
|
||||
# Handle object format
|
||||
elif hasattr(delta, "content"):
|
||||
elif isinstance(delta, LiteLLMDelta):
|
||||
chunk_content = delta.content
|
||||
|
||||
# Handle case where content might be None or empty
|
||||
@@ -821,9 +830,8 @@ class LLM(BaseLLM):
|
||||
choices = None
|
||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||
choices = last_chunk["choices"]
|
||||
elif hasattr(last_chunk, "choices"):
|
||||
if not isinstance(last_chunk.choices, type):
|
||||
choices = last_chunk.choices
|
||||
elif isinstance(last_chunk, ModelResponseStream):
|
||||
choices = last_chunk.choices
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
@@ -832,14 +840,14 @@ class LLM(BaseLLM):
|
||||
message = None
|
||||
if isinstance(choice, dict) and "message" in choice:
|
||||
message = choice["message"]
|
||||
elif hasattr(choice, "message"):
|
||||
elif isinstance(choice, Choices):
|
||||
message = choice.message
|
||||
|
||||
if message:
|
||||
content = None
|
||||
if isinstance(message, dict) and "content" in message:
|
||||
content = message["content"]
|
||||
elif hasattr(message, "content"):
|
||||
elif isinstance(message, Message):
|
||||
content = message.content
|
||||
|
||||
if content:
|
||||
@@ -866,24 +874,23 @@ class LLM(BaseLLM):
|
||||
choices = None
|
||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||
choices = last_chunk["choices"]
|
||||
elif hasattr(last_chunk, "choices"):
|
||||
if not isinstance(last_chunk.choices, type):
|
||||
choices = last_chunk.choices
|
||||
elif isinstance(last_chunk, ModelResponseStream):
|
||||
choices = last_chunk.choices
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
|
||||
message = None
|
||||
if isinstance(choice, dict) and "message" in choice:
|
||||
message = choice["message"]
|
||||
elif hasattr(choice, "message"):
|
||||
message = choice.message
|
||||
delta = None
|
||||
if isinstance(choice, dict) and "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
elif isinstance(choice, LiteLLMStreamingChoices):
|
||||
delta = choice.delta
|
||||
|
||||
if message:
|
||||
if isinstance(message, dict) and "tool_calls" in message:
|
||||
tool_calls = message["tool_calls"]
|
||||
elif hasattr(message, "tool_calls"):
|
||||
tool_calls = message.tool_calls
|
||||
if delta:
|
||||
if isinstance(delta, dict) and "tool_calls" in delta:
|
||||
tool_calls = delta["tool_calls"]
|
||||
elif isinstance(delta, LiteLLMDelta):
|
||||
tool_calls = delta.tool_calls
|
||||
except Exception as e:
|
||||
logging.debug(f"Error checking for tool calls: {e}")
|
||||
|
||||
@@ -1037,7 +1044,7 @@ class LLM(BaseLLM):
|
||||
"""
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
if isinstance(callback, TokenCalcHandler):
|
||||
# Use the usage_info we've been tracking
|
||||
if not usage_info:
|
||||
# Try to get usage from the last chunk if we haven't already
|
||||
@@ -1048,9 +1055,14 @@ class LLM(BaseLLM):
|
||||
and "usage" in last_chunk
|
||||
):
|
||||
usage_info = last_chunk["usage"]
|
||||
elif hasattr(last_chunk, "usage"):
|
||||
if not isinstance(last_chunk.usage, type):
|
||||
usage_info = last_chunk.usage
|
||||
elif (
|
||||
isinstance(last_chunk, ModelResponseBase)
|
||||
and last_chunk.model_extra
|
||||
):
|
||||
usage_info = (
|
||||
last_chunk.model_extra.get("usage")
|
||||
or usage_info
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug(f"Error extracting usage info: {e}")
|
||||
|
||||
@@ -1123,13 +1135,10 @@ class LLM(BaseLLM):
|
||||
params["response_model"] = response_model
|
||||
response = litellm.completion(**params)
|
||||
|
||||
if (
|
||||
hasattr(response, "usage")
|
||||
and not isinstance(response.usage, type)
|
||||
and response.usage
|
||||
):
|
||||
usage_info = response.usage
|
||||
self._track_token_usage_internal(usage_info)
|
||||
if isinstance(response, ModelResponseBase) and response.model_extra:
|
||||
usage_info = response.model_extra.get("usage")
|
||||
if usage_info:
|
||||
self._track_token_usage_internal(usage_info)
|
||||
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise our own context length error
|
||||
@@ -1141,7 +1150,11 @@ class LLM(BaseLLM):
|
||||
raise LLMContextLengthExceededError(error_msg) from e
|
||||
raise
|
||||
|
||||
response_usage = self._usage_to_dict(getattr(response, "usage", None))
|
||||
response_usage = self._usage_to_dict(
|
||||
response.model_extra.get("usage")
|
||||
if isinstance(response, ModelResponseBase) and response.model_extra
|
||||
else None
|
||||
)
|
||||
|
||||
# --- 2) Handle structured output response (when response_model is provided)
|
||||
if response_model is not None:
|
||||
@@ -1166,8 +1179,13 @@ class LLM(BaseLLM):
|
||||
# --- 3) Handle callbacks with usage info
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
usage_info = getattr(response, "usage", None)
|
||||
if isinstance(callback, TokenCalcHandler):
|
||||
usage_info = (
|
||||
response.model_extra.get("usage")
|
||||
if isinstance(response, ModelResponseBase)
|
||||
and response.model_extra
|
||||
else None
|
||||
)
|
||||
if usage_info:
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
@@ -1176,7 +1194,7 @@ class LLM(BaseLLM):
|
||||
end_time=0,
|
||||
)
|
||||
# --- 4) Check for tool calls
|
||||
tool_calls = getattr(response_message, "tool_calls", [])
|
||||
tool_calls = response_message.tool_calls or []
|
||||
|
||||
# --- 5) If no tool calls or no available functions, return the text response directly as long as there is a text response
|
||||
if (not tool_calls or not available_functions) and text_response:
|
||||
@@ -1269,13 +1287,10 @@ class LLM(BaseLLM):
|
||||
params["response_model"] = response_model
|
||||
response = await litellm.acompletion(**params)
|
||||
|
||||
if (
|
||||
hasattr(response, "usage")
|
||||
and not isinstance(response.usage, type)
|
||||
and response.usage
|
||||
):
|
||||
usage_info = response.usage
|
||||
self._track_token_usage_internal(usage_info)
|
||||
if isinstance(response, ModelResponseBase) and response.model_extra:
|
||||
usage_info = response.model_extra.get("usage")
|
||||
if usage_info:
|
||||
self._track_token_usage_internal(usage_info)
|
||||
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise our own context length error
|
||||
@@ -1287,7 +1302,11 @@ class LLM(BaseLLM):
|
||||
raise LLMContextLengthExceededError(error_msg) from e
|
||||
raise
|
||||
|
||||
response_usage = self._usage_to_dict(getattr(response, "usage", None))
|
||||
response_usage = self._usage_to_dict(
|
||||
response.model_extra.get("usage")
|
||||
if isinstance(response, ModelResponseBase) and response.model_extra
|
||||
else None
|
||||
)
|
||||
|
||||
if response_model is not None:
|
||||
if isinstance(response, BaseModel):
|
||||
@@ -1309,8 +1328,13 @@ class LLM(BaseLLM):
|
||||
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
usage_info = getattr(response, "usage", None)
|
||||
if isinstance(callback, TokenCalcHandler):
|
||||
usage_info = (
|
||||
response.model_extra.get("usage")
|
||||
if isinstance(response, ModelResponseBase)
|
||||
and response.model_extra
|
||||
else None
|
||||
)
|
||||
if usage_info:
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
@@ -1319,7 +1343,7 @@ class LLM(BaseLLM):
|
||||
end_time=0,
|
||||
)
|
||||
|
||||
tool_calls = getattr(response_message, "tool_calls", [])
|
||||
tool_calls = response_message.tool_calls or []
|
||||
|
||||
if (not tool_calls or not available_functions) and text_response:
|
||||
self._handle_emit_call_events(
|
||||
@@ -1394,18 +1418,19 @@ class LLM(BaseLLM):
|
||||
async for chunk in await litellm.acompletion(**params):
|
||||
chunk_count += 1
|
||||
chunk_content = None
|
||||
response_id = chunk.id if hasattr(chunk, "id") else None
|
||||
response_id = chunk.id if isinstance(chunk, ModelResponseBase) else None
|
||||
|
||||
try:
|
||||
choices = None
|
||||
if isinstance(chunk, dict) and "choices" in chunk:
|
||||
choices = chunk["choices"]
|
||||
elif hasattr(chunk, "choices"):
|
||||
if not isinstance(chunk.choices, type):
|
||||
choices = chunk.choices
|
||||
elif isinstance(chunk, ModelResponseStream):
|
||||
choices = chunk.choices
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||
usage_info = chunk.usage
|
||||
if isinstance(chunk, ModelResponseBase) and chunk.model_extra:
|
||||
chunk_usage = chunk.model_extra.get("usage")
|
||||
if chunk_usage is not None:
|
||||
usage_info = chunk_usage
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
first_choice = choices[0]
|
||||
@@ -1413,19 +1438,19 @@ class LLM(BaseLLM):
|
||||
|
||||
if isinstance(first_choice, dict):
|
||||
delta = first_choice.get("delta", {})
|
||||
elif hasattr(first_choice, "delta"):
|
||||
elif isinstance(first_choice, LiteLLMStreamingChoices):
|
||||
delta = first_choice.delta
|
||||
|
||||
if delta:
|
||||
if isinstance(delta, dict):
|
||||
chunk_content = delta.get("content")
|
||||
elif hasattr(delta, "content"):
|
||||
elif isinstance(delta, LiteLLMDelta):
|
||||
chunk_content = delta.content
|
||||
|
||||
tool_calls: list[ChatCompletionDeltaToolCall] | None = None
|
||||
if isinstance(delta, dict):
|
||||
tool_calls = delta.get("tool_calls")
|
||||
elif hasattr(delta, "tool_calls"):
|
||||
elif isinstance(delta, LiteLLMDelta):
|
||||
tool_calls = delta.tool_calls
|
||||
|
||||
if tool_calls:
|
||||
@@ -1461,7 +1486,7 @@ class LLM(BaseLLM):
|
||||
|
||||
if callbacks and len(callbacks) > 0 and usage_info:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
if isinstance(callback, TokenCalcHandler):
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
response_obj={"usage": usage_info},
|
||||
@@ -1920,7 +1945,7 @@ class LLM(BaseLLM):
|
||||
return None
|
||||
if isinstance(usage, dict):
|
||||
return usage
|
||||
if hasattr(usage, "model_dump"):
|
||||
if isinstance(usage, BaseModel):
|
||||
result: dict[str, Any] = usage.model_dump()
|
||||
return result
|
||||
if hasattr(usage, "__dict__"):
|
||||
@@ -1984,7 +2009,7 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return messages
|
||||
|
||||
provider = getattr(self, "provider", None) or self.model
|
||||
provider = self.provider or self.model
|
||||
|
||||
for msg in messages:
|
||||
files = msg.get("files")
|
||||
@@ -2035,7 +2060,7 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return messages
|
||||
|
||||
provider = getattr(self, "provider", None) or self.model
|
||||
provider = self.provider or self.model
|
||||
|
||||
for msg in messages:
|
||||
files = msg.get("files")
|
||||
|
||||
@@ -45,6 +45,7 @@ from crewai.events.types.task_events import (
|
||||
TaskStartedEvent,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
@@ -301,12 +302,14 @@ class Task(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_required_fields(self) -> Self:
|
||||
required_fields = ["description", "expected_output"]
|
||||
for field in required_fields:
|
||||
if getattr(self, field) is None:
|
||||
raise ValueError(
|
||||
f"{field} must be provided either directly or through config"
|
||||
)
|
||||
if self.description is None:
|
||||
raise ValueError(
|
||||
"description must be provided either directly or through config"
|
||||
)
|
||||
if self.expected_output is None:
|
||||
raise ValueError(
|
||||
"expected_output must be provided either directly or through config"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -838,8 +841,8 @@ class Task(BaseModel):
|
||||
should_inject = self.allow_crewai_trigger_context
|
||||
|
||||
if should_inject and self.agent:
|
||||
crew = getattr(self.agent, "crew", None)
|
||||
if crew and hasattr(crew, "_inputs") and crew._inputs:
|
||||
crew = self.agent.crew
|
||||
if crew and not isinstance(crew, str) and crew._inputs:
|
||||
trigger_payload = crew._inputs.get("crewai_trigger_payload")
|
||||
if trigger_payload is not None:
|
||||
description += f"\n\nTrigger Payload: {trigger_payload}"
|
||||
@@ -852,11 +855,12 @@ class Task(BaseModel):
|
||||
isinstance(self.agent.llm, BaseLLM)
|
||||
and self.agent.llm.supports_multimodal()
|
||||
):
|
||||
provider: str = str(
|
||||
getattr(self.agent.llm, "provider", None)
|
||||
or getattr(self.agent.llm, "model", "openai")
|
||||
provider: str = self.agent.llm.provider or self.agent.llm.model
|
||||
api: str | None = (
|
||||
self.agent.llm.api
|
||||
if isinstance(self.agent.llm, OpenAICompletion)
|
||||
else None
|
||||
)
|
||||
api: str | None = getattr(self.agent.llm, "api", None)
|
||||
supported_types = get_supported_content_types(provider, api)
|
||||
|
||||
def is_auto_injected(content_type: str) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user