chore: improve typing in task module

This commit is contained in:
Greyson LaLonde
2026-04-13 11:35:25 +08:00
parent 1d6f84c7aa
commit af612b05e0

View File

@@ -45,6 +45,7 @@ from crewai.events.types.task_events import (
TaskStartedEvent, TaskStartedEvent,
) )
from crewai.llms.base_llm import BaseLLM from crewai.llms.base_llm import BaseLLM
from crewai.llms.providers.openai.completion import OpenAICompletion
from crewai.security import Fingerprint, SecurityConfig from crewai.security import Fingerprint, SecurityConfig
from crewai.tasks.output_format import OutputFormat from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput from crewai.tasks.task_output import TaskOutput
@@ -301,12 +302,14 @@ class Task(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def validate_required_fields(self) -> Self: def validate_required_fields(self) -> Self:
required_fields = ["description", "expected_output"] if self.description is None:
for field in required_fields: raise ValueError(
if getattr(self, field) is None: "description must be provided either directly or through config"
raise ValueError( )
f"{field} must be provided either directly or through config" if self.expected_output is None:
) raise ValueError(
"expected_output must be provided either directly or through config"
)
return self return self
@model_validator(mode="after") @model_validator(mode="after")
@@ -838,8 +841,8 @@ class Task(BaseModel):
should_inject = self.allow_crewai_trigger_context should_inject = self.allow_crewai_trigger_context
if should_inject and self.agent: if should_inject and self.agent:
crew = getattr(self.agent, "crew", None) crew = self.agent.crew
if crew and hasattr(crew, "_inputs") and crew._inputs: if crew and not isinstance(crew, str) and crew._inputs:
trigger_payload = crew._inputs.get("crewai_trigger_payload") trigger_payload = crew._inputs.get("crewai_trigger_payload")
if trigger_payload is not None: if trigger_payload is not None:
description += f"\n\nTrigger Payload: {trigger_payload}" description += f"\n\nTrigger Payload: {trigger_payload}"
@@ -849,14 +852,11 @@ class Task(BaseModel):
if files: if files:
supported_types: list[str] = [] supported_types: list[str] = []
if ( if (
isinstance(self.agent.llm, BaseLLM) isinstance(self.agent.llm, OpenAICompletion)
and self.agent.llm.supports_multimodal() and self.agent.llm.supports_multimodal()
): ):
provider: str = str( provider: str = self.agent.llm.provider or self.agent.llm.model
getattr(self.agent.llm, "provider", None) api: str | None = self.agent.llm.api
or getattr(self.agent.llm, "model", "openai")
)
api: str | None = getattr(self.agent.llm, "api", None)
supported_types = get_supported_content_types(provider, api) supported_types = get_supported_content_types(provider, api)
def is_auto_injected(content_type: str) -> bool: def is_auto_injected(content_type: str) -> bool: