chore: improve typing in task module

This commit is contained in:
Greyson LaLonde
2026-04-13 11:35:25 +08:00
parent 1d6f84c7aa
commit af612b05e0

View File

@@ -45,6 +45,7 @@ from crewai.events.types.task_events import (
TaskStartedEvent,
)
from crewai.llms.base_llm import BaseLLM
from crewai.llms.providers.openai.completion import OpenAICompletion
from crewai.security import Fingerprint, SecurityConfig
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
@@ -301,12 +302,14 @@ class Task(BaseModel):
@model_validator(mode="after")
def validate_required_fields(self) -> Self:
required_fields = ["description", "expected_output"]
for field in required_fields:
if getattr(self, field) is None:
raise ValueError(
f"{field} must be provided either directly or through config"
)
if self.description is None:
raise ValueError(
"description must be provided either directly or through config"
)
if self.expected_output is None:
raise ValueError(
"expected_output must be provided either directly or through config"
)
return self
@model_validator(mode="after")
@@ -838,8 +841,8 @@ class Task(BaseModel):
should_inject = self.allow_crewai_trigger_context
if should_inject and self.agent:
crew = getattr(self.agent, "crew", None)
if crew and hasattr(crew, "_inputs") and crew._inputs:
crew = self.agent.crew
if crew and not isinstance(crew, str) and crew._inputs:
trigger_payload = crew._inputs.get("crewai_trigger_payload")
if trigger_payload is not None:
description += f"\n\nTrigger Payload: {trigger_payload}"
@@ -849,14 +852,11 @@ class Task(BaseModel):
if files:
supported_types: list[str] = []
if (
isinstance(self.agent.llm, BaseLLM)
isinstance(self.agent.llm, OpenAICompletion)
and self.agent.llm.supports_multimodal()
):
provider: str = str(
getattr(self.agent.llm, "provider", None)
or getattr(self.agent.llm, "model", "openai")
)
api: str | None = getattr(self.agent.llm, "api", None)
provider: str = self.agent.llm.provider or self.agent.llm.model
api: str | None = self.agent.llm.api
supported_types = get_supported_content_types(provider, api)
def is_auto_injected(content_type: str) -> bool: