diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py
index e6f5cc68b..db126954e 100644
--- a/lib/crewai/src/crewai/llm.py
+++ b/lib/crewai/src/crewai/llm.py
@@ -51,6 +51,7 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
 )
 from crewai.utilities.logger_utils import suppress_warnings
 from crewai.utilities.string_utils import sanitize_tool_name
+from crewai.utilities.token_counter_callback import TokenCalcHandler
 
 
 try:
@@ -75,8 +76,13 @@ try:
     from litellm.types.utils import (
         ChatCompletionDeltaToolCall,
         Choices,
+        Delta as LiteLLMDelta,
         Function,
+        Message,
         ModelResponse,
+        ModelResponseBase,
+        ModelResponseStream,
+        StreamingChoices as LiteLLMStreamingChoices,
     )
     from litellm.utils import supports_response_schema
 
@@ -85,6 +91,11 @@ except ImportError:
     LITELLM_AVAILABLE = False
     litellm = None  # type: ignore[assignment]
     Choices = None  # type: ignore[assignment, misc]
+    LiteLLMDelta = None  # type: ignore[assignment, misc]
+    Message = None  # type: ignore[assignment, misc]
+    ModelResponseBase = None  # type: ignore[assignment, misc]
+    ModelResponseStream = None  # type: ignore[assignment, misc]
+    LiteLLMStreamingChoices = None  # type: ignore[assignment, misc]
     get_supported_openai_params = None  # type: ignore[assignment]
     ChatCompletionDeltaToolCall = None  # type: ignore[assignment, misc]
     Function = None  # type: ignore[assignment, misc]
@@ -709,7 +720,7 @@ class LLM(BaseLLM):
 
                 chunk_content = None
                 response_id = None
-                if hasattr(chunk, "id"):
+                if isinstance(chunk, ModelResponseBase):
                     response_id = chunk.id
 
                 # Safely extract content from various chunk formats
@@ -718,18 +729,16 @@ class LLM(BaseLLM):
                     choices = None
                     if isinstance(chunk, dict) and "choices" in chunk:
                         choices = chunk["choices"]
-                    elif hasattr(chunk, "choices"):
-                        # Check if choices is not a type but an actual attribute with value
-                        if not isinstance(chunk.choices, type):
-                            choices = chunk.choices
+                    elif isinstance(chunk, ModelResponseStream):
+                        choices = chunk.choices
 
                     # Try to extract usage information if available
+                    # NOTE: usage is a pydantic extra field on ModelResponseBase,
+                    # so it must be accessed via model_extra.
                     if isinstance(chunk, dict) and "usage" in chunk:
                         usage_info = chunk["usage"]
-                    elif hasattr(chunk, "usage"):
-                        # Check if usage is not a type but an actual attribute with value
-                        if not isinstance(chunk.usage, type):
-                            usage_info = chunk.usage
+                    elif isinstance(chunk, ModelResponseBase) and chunk.model_extra:
+                        usage_info = chunk.model_extra.get("usage") or usage_info
 
                     if choices and len(choices) > 0:
                         choice = choices[0]
@@ -738,7 +747,7 @@ class LLM(BaseLLM):
                         delta = None
                         if isinstance(choice, dict) and "delta" in choice:
                             delta = choice["delta"]
-                        elif hasattr(choice, "delta"):
+                        elif isinstance(choice, LiteLLMStreamingChoices):
                             delta = choice.delta
 
                         # Extract content from delta
@@ -748,7 +757,7 @@ class LLM(BaseLLM):
                             if "content" in delta and delta["content"] is not None:
                                 chunk_content = delta["content"]
                             # Handle object format
-                            elif hasattr(delta, "content"):
+                            elif isinstance(delta, LiteLLMDelta):
                                 chunk_content = delta.content
 
                 # Handle case where content might be None or empty
@@ -821,9 +830,8 @@ class LLM(BaseLLM):
                     choices = None
                     if isinstance(last_chunk, dict) and "choices" in last_chunk:
                         choices = last_chunk["choices"]
-                    elif hasattr(last_chunk, "choices"):
-                        if not isinstance(last_chunk.choices, type):
-                            choices = last_chunk.choices
+                    elif isinstance(last_chunk, ModelResponseStream):
+                        choices = last_chunk.choices
 
                     if choices and len(choices) > 0:
                         choice = choices[0]
@@ -832,14 +840,14 @@ class LLM(BaseLLM):
                         message = None
                         if isinstance(choice, dict) and "message" in choice:
                             message = choice["message"]
-                        elif hasattr(choice, "message"):
+                        elif isinstance(choice, Choices):
                             message = choice.message
 
                         if message:
                             content = None
                             if isinstance(message, dict) and "content" in message:
                                 content = message["content"]
-                            elif hasattr(message, "content"):
+                            elif isinstance(message, Message):
                                 content = message.content
 
                             if content:
@@ -866,24 +874,23 @@ class LLM(BaseLLM):
                     choices = None
                     if isinstance(last_chunk, dict) and "choices" in last_chunk:
                         choices = last_chunk["choices"]
-                    elif hasattr(last_chunk, "choices"):
-                        if not isinstance(last_chunk.choices, type):
-                            choices = last_chunk.choices
+                    elif isinstance(last_chunk, ModelResponseStream):
+                        choices = last_chunk.choices
 
                     if choices and len(choices) > 0:
                         choice = choices[0]
 
-                        message = None
-                        if isinstance(choice, dict) and "message" in choice:
-                            message = choice["message"]
-                        elif hasattr(choice, "message"):
-                            message = choice.message
+                        delta = None
+                        if isinstance(choice, dict) and "delta" in choice:
+                            delta = choice["delta"]
+                        elif isinstance(choice, LiteLLMStreamingChoices):
+                            delta = choice.delta
 
-                        if message:
-                            if isinstance(message, dict) and "tool_calls" in message:
-                                tool_calls = message["tool_calls"]
-                            elif hasattr(message, "tool_calls"):
-                                tool_calls = message.tool_calls
+                        if delta:
+                            if isinstance(delta, dict) and "tool_calls" in delta:
+                                tool_calls = delta["tool_calls"]
+                            elif isinstance(delta, LiteLLMDelta):
+                                tool_calls = delta.tool_calls
 
             except Exception as e:
                 logging.debug(f"Error checking for tool calls: {e}")
@@ -1037,7 +1044,7 @@ class LLM(BaseLLM):
         """
         if callbacks and len(callbacks) > 0:
             for callback in callbacks:
-                if hasattr(callback, "log_success_event"):
+                if isinstance(callback, TokenCalcHandler):
                     # Use the usage_info we've been tracking
                     if not usage_info:
                         # Try to get usage from the last chunk if we haven't already
@@ -1048,9 +1055,14 @@ class LLM(BaseLLM):
                                 and "usage" in last_chunk
                             ):
                                 usage_info = last_chunk["usage"]
-                            elif hasattr(last_chunk, "usage"):
-                                if not isinstance(last_chunk.usage, type):
-                                    usage_info = last_chunk.usage
+                            elif (
+                                isinstance(last_chunk, ModelResponseBase)
+                                and last_chunk.model_extra
+                            ):
+                                usage_info = (
+                                    last_chunk.model_extra.get("usage")
+                                    or usage_info
+                                )
                         except Exception as e:
                             logging.debug(f"Error extracting usage info: {e}")
 
@@ -1123,13 +1135,10 @@ class LLM(BaseLLM):
                 params["response_model"] = response_model
 
             response = litellm.completion(**params)
-            if (
-                hasattr(response, "usage")
-                and not isinstance(response.usage, type)
-                and response.usage
-            ):
-                usage_info = response.usage
-                self._track_token_usage_internal(usage_info)
+            if isinstance(response, ModelResponseBase) and response.model_extra:
+                usage_info = response.model_extra.get("usage")
+                if usage_info:
+                    self._track_token_usage_internal(usage_info)
 
         except LLMContextLengthExceededError:
             # Re-raise our own context length error
@@ -1141,7 +1150,11 @@ class LLM(BaseLLM):
                 raise LLMContextLengthExceededError(error_msg) from e
             raise
 
-        response_usage = self._usage_to_dict(getattr(response, "usage", None))
+        response_usage = self._usage_to_dict(
+            response.model_extra.get("usage")
+            if isinstance(response, ModelResponseBase) and response.model_extra
+            else None
+        )
 
         # --- 2) Handle structured output response (when response_model is provided)
         if response_model is not None:
@@ -1166,8 +1179,13 @@ class LLM(BaseLLM):
         # --- 3) Handle callbacks with usage info
         if callbacks and len(callbacks) > 0:
             for callback in callbacks:
-                if hasattr(callback, "log_success_event"):
-                    usage_info = getattr(response, "usage", None)
+                if isinstance(callback, TokenCalcHandler):
+                    usage_info = (
+                        response.model_extra.get("usage")
+                        if isinstance(response, ModelResponseBase)
+                        and response.model_extra
+                        else None
+                    )
                     if usage_info:
                         callback.log_success_event(
                             kwargs=params,
@@ -1176,7 +1194,7 @@
                             end_time=0,
                         )
         # --- 4) Check for tool calls
-        tool_calls = getattr(response_message, "tool_calls", [])
+        tool_calls = response_message.tool_calls or []
 
         # --- 5) If no tool calls or no available functions, return the text response directly as long as there is a text response
         if (not tool_calls or not available_functions) and text_response:
@@ -1269,13 +1287,10 @@ class LLM(BaseLLM):
                 params["response_model"] = response_model
 
             response = await litellm.acompletion(**params)
-            if (
-                hasattr(response, "usage")
-                and not isinstance(response.usage, type)
-                and response.usage
-            ):
-                usage_info = response.usage
-                self._track_token_usage_internal(usage_info)
+            if isinstance(response, ModelResponseBase) and response.model_extra:
+                usage_info = response.model_extra.get("usage")
+                if usage_info:
+                    self._track_token_usage_internal(usage_info)
 
         except LLMContextLengthExceededError:
             # Re-raise our own context length error
@@ -1287,7 +1302,11 @@ class LLM(BaseLLM):
                 raise LLMContextLengthExceededError(error_msg) from e
             raise
 
-        response_usage = self._usage_to_dict(getattr(response, "usage", None))
+        response_usage = self._usage_to_dict(
+            response.model_extra.get("usage")
+            if isinstance(response, ModelResponseBase) and response.model_extra
+            else None
+        )
 
         if response_model is not None:
             if isinstance(response, BaseModel):
@@ -1309,8 +1328,13 @@ class LLM(BaseLLM):
 
         if callbacks and len(callbacks) > 0:
             for callback in callbacks:
-                if hasattr(callback, "log_success_event"):
-                    usage_info = getattr(response, "usage", None)
+                if isinstance(callback, TokenCalcHandler):
+                    usage_info = (
+                        response.model_extra.get("usage")
+                        if isinstance(response, ModelResponseBase)
+                        and response.model_extra
+                        else None
+                    )
                     if usage_info:
                         callback.log_success_event(
                             kwargs=params,
@@ -1319,7 +1343,7 @@
                             end_time=0,
                         )
 
-        tool_calls = getattr(response_message, "tool_calls", [])
+        tool_calls = response_message.tool_calls or []
 
         if (not tool_calls or not available_functions) and text_response:
             self._handle_emit_call_events(
@@ -1394,18 +1418,19 @@ class LLM(BaseLLM):
         async for chunk in await litellm.acompletion(**params):
             chunk_count += 1
             chunk_content = None
-            response_id = chunk.id if hasattr(chunk, "id") else None
+            response_id = chunk.id if isinstance(chunk, ModelResponseBase) else None
 
             try:
                 choices = None
                 if isinstance(chunk, dict) and "choices" in chunk:
                     choices = chunk["choices"]
-                elif hasattr(chunk, "choices"):
-                    if not isinstance(chunk.choices, type):
-                        choices = chunk.choices
+                elif isinstance(chunk, ModelResponseStream):
+                    choices = chunk.choices
 
-                if hasattr(chunk, "usage") and chunk.usage is not None:
-                    usage_info = chunk.usage
+                if isinstance(chunk, ModelResponseBase) and chunk.model_extra:
+                    chunk_usage = chunk.model_extra.get("usage")
+                    if chunk_usage is not None:
+                        usage_info = chunk_usage
 
                 if choices and len(choices) > 0:
                     first_choice = choices[0]
@@ -1413,19 +1438,19 @@
 
                     if isinstance(first_choice, dict):
                         delta = first_choice.get("delta", {})
-                    elif hasattr(first_choice, "delta"):
+                    elif isinstance(first_choice, LiteLLMStreamingChoices):
                         delta = first_choice.delta
 
                     if delta:
                         if isinstance(delta, dict):
                             chunk_content = delta.get("content")
-                        elif hasattr(delta, "content"):
+                        elif isinstance(delta, LiteLLMDelta):
                             chunk_content = delta.content
 
                         tool_calls: list[ChatCompletionDeltaToolCall] | None = None
                         if isinstance(delta, dict):
                             tool_calls = delta.get("tool_calls")
-                        elif hasattr(delta, "tool_calls"):
+                        elif isinstance(delta, LiteLLMDelta):
                             tool_calls = delta.tool_calls
 
                         if tool_calls:
@@ -1461,7 +1486,7 @@
 
         if callbacks and len(callbacks) > 0 and usage_info:
             for callback in callbacks:
-                if hasattr(callback, "log_success_event"):
+                if isinstance(callback, TokenCalcHandler):
                     callback.log_success_event(
                         kwargs=params,
                         response_obj={"usage": usage_info},
@@ -1920,7 +1945,7 @@ class LLM(BaseLLM):
             return None
         if isinstance(usage, dict):
            return usage
-        if hasattr(usage, "model_dump"):
+        if isinstance(usage, BaseModel):
            result: dict[str, Any] = usage.model_dump()
            return result
        if hasattr(usage, "__dict__"):
@@ -1984,7 +2009,7 @@ class LLM(BaseLLM):
             )
             return messages
 
-        provider = getattr(self, "provider", None) or self.model
+        provider = self.provider or self.model
 
         for msg in messages:
             files = msg.get("files")
@@ -2035,7 +2060,7 @@ class LLM(BaseLLM):
             )
             return messages
 
-        provider = getattr(self, "provider", None) or self.model
+        provider = self.provider or self.model
 
         for msg in messages:
             files = msg.get("files")
diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py
index 0a159cb0e..e12caa2af 100644
--- a/lib/crewai/src/crewai/task.py
+++ b/lib/crewai/src/crewai/task.py
@@ -45,6 +45,7 @@ from crewai.events.types.task_events import (
     TaskStartedEvent,
 )
 from crewai.llms.base_llm import BaseLLM
+from crewai.llms.providers.openai.completion import OpenAICompletion
 from crewai.security import Fingerprint, SecurityConfig
 from crewai.tasks.output_format import OutputFormat
 from crewai.tasks.task_output import TaskOutput
@@ -301,12 +302,14 @@ class Task(BaseModel):
 
     @model_validator(mode="after")
     def validate_required_fields(self) -> Self:
-        required_fields = ["description", "expected_output"]
-        for field in required_fields:
-            if getattr(self, field) is None:
-                raise ValueError(
-                    f"{field} must be provided either directly or through config"
-                )
+        if self.description is None:
+            raise ValueError(
+                "description must be provided either directly or through config"
+            )
+        if self.expected_output is None:
+            raise ValueError(
+                "expected_output must be provided either directly or through config"
+            )
         return self
 
     @model_validator(mode="after")
@@ -838,8 +841,8 @@ class Task(BaseModel):
             should_inject = self.allow_crewai_trigger_context
 
         if should_inject and self.agent:
-            crew = getattr(self.agent, "crew", None)
-            if crew and hasattr(crew, "_inputs") and crew._inputs:
+            crew = self.agent.crew
+            if crew and not isinstance(crew, str) and crew._inputs:
                 trigger_payload = crew._inputs.get("crewai_trigger_payload")
                 if trigger_payload is not None:
                     description += f"\n\nTrigger Payload: {trigger_payload}"
@@ -852,11 +855,12 @@ class Task(BaseModel):
             isinstance(self.agent.llm, BaseLLM)
             and self.agent.llm.supports_multimodal()
         ):
-            provider: str = str(
-                getattr(self.agent.llm, "provider", None)
-                or getattr(self.agent.llm, "model", "openai")
+            provider: str = self.agent.llm.provider or self.agent.llm.model
+            api: str | None = (
+                self.agent.llm.api
+                if isinstance(self.agent.llm, OpenAICompletion)
+                else None
             )
-            api: str | None = getattr(self.agent.llm, "api", None)
             supported_types = get_supported_content_types(provider, api)
 
             def is_auto_injected(content_type: str) -> bool: