diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index 94ddfe31e..4144f9ab4 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -98,6 +98,9 @@ azure-ai-inference = [ anthropic = [ "anthropic~=0.73.0", ] +cerebras = [ + "cerebras-cloud-sdk~=1.67.0", +] a2a = [ "a2a-sdk~=0.3.10", "httpx-auth~=0.23.1", diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 52e3b0b9f..7e4f89a91 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -606,6 +606,13 @@ class LLM(BaseLLM): return BedrockCompletion + if provider == "cerebras": + from crewai.llms.providers.cerebras.completion import ( + CerebrasCompletion, + ) + + return CerebrasCompletion + # OpenAI-compatible providers openai_compatible_providers = { "openrouter", diff --git a/lib/crewai/src/crewai/llms/constants.py b/lib/crewai/src/crewai/llms/constants.py index 260c23daf..23735c08a 100644 --- a/lib/crewai/src/crewai/llms/constants.py +++ b/lib/crewai/src/crewai/llms/constants.py @@ -523,6 +523,20 @@ BedrockModels: TypeAlias = Literal[ "qwen.qwen3-coder-30b-a3b-v1:0", "twelvelabs.pegasus-1-2-v1:0", ] +CerebrasModels: TypeAlias = Literal[ + "llama3.1-8b", + "gpt-oss-120b", + "qwen-3-235b-a22b-instruct-2507", + "zai-glm-4.7", +] +CEREBRAS_MODELS: list[CerebrasModels] = [ + "llama3.1-8b", + "gpt-oss-120b", + "qwen-3-235b-a22b-instruct-2507", + "zai-glm-4.7", +] + + BEDROCK_MODELS: list[BedrockModels] = [ # Inference profiles (regional) - Claude 4 "us.anthropic.claude-sonnet-4-5-20250929-v1:0", diff --git a/lib/crewai/src/crewai/llms/providers/cerebras/__init__.py b/lib/crewai/src/crewai/llms/providers/cerebras/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/crewai/src/crewai/llms/providers/cerebras/completion.py b/lib/crewai/src/crewai/llms/providers/cerebras/completion.py new file mode 100644 index 000000000..443cb46cb --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/cerebras/completion.py @@ -0,0 +1,1148 @@ +"""Cerebras native completion implementation. + +Uses the official ``cerebras-cloud-sdk`` (:class:`Cerebras` / :class:`AsyncCerebras`) +directly for ``chat.completions.create``. This class subclasses +:class:`~crewai.llms.base_llm.BaseLLM` only — it does not inherit from the OpenAI +provider — while following the same chat-completion request shape the Cerebras API +expects (OpenAI-compatible HTTP surface). + +Install the optional dependency: ``uv add "crewai[cerebras]"``. 
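+
+Example (a minimal usage sketch; assumes ``CEREBRAS_API_KEY`` is exported, and
+that the ``LLM`` factory forwards provider kwargs the same way it forwards
+``api_key`` in the tests)::
+
+    from crewai.llm import LLM
+
+    # The "cerebras/" model prefix routes the factory to CerebrasCompletion.
+    llm = LLM(model="cerebras/llama3.1-8b", max_completion_tokens=32)
+    print(llm.call("Reply with exactly the word: OK"))
+
+    # Cerebras-specific request fields thread through to chat.completions.create.
+    tiered = LLM(
+        model="cerebras/gpt-oss-120b",
+        service_tier="priority",
+        reasoning_effort="high",
+    )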
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +import json +import logging +import os +from typing import TYPE_CHECKING, Any, Literal + +import httpx +from pydantic import BaseModel, PrivateAttr, model_validator + +from crewai.events.types.llm_events import LLMCallType +from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context +from crewai.llms.hooks.base import BaseInterceptor +from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport +from crewai.llms.providers.utils.common import safe_tool_conversion +from crewai.utilities.agent_utils import is_context_length_exceeded +from crewai.utilities.exceptions.context_window_exceeding_exception import ( + LLMContextLengthExceededError, +) +from crewai.utilities.pydantic_schema_utils import ( + generate_model_description, + sanitize_tool_params_for_openai_strict, +) +from crewai.utilities.types import LLMMessage + + +if TYPE_CHECKING: + from crewai.agents.agent_builder.base_agent import BaseAgent + from crewai.task import Task + from crewai.tools.base_tool import BaseTool + +try: + from cerebras.cloud.sdk import ( + APIConnectionError, + AsyncCerebras, + Cerebras, + NotFoundError, + ) +except ImportError: + raise ImportError( + 'Cerebras native provider not available, to install: uv add "crewai[cerebras]"' + ) from None + + +CEREBRAS_BASE_URL_ENV = "CEREBRAS_BASE_URL" +CEREBRAS_API_KEY_ENV = "CEREBRAS_API_KEY" + + +def _extract_chat_usage(response: Any) -> dict[str, Any]: + """Best-effort usage extraction; works for Cerebras and OpenAI-shaped responses.""" + if hasattr(response, "usage") and response.usage: + usage = response.usage + result: dict[str, Any] = { + "prompt_tokens": getattr(usage, "prompt_tokens", 0), + "completion_tokens": getattr(usage, "completion_tokens", 0), + "total_tokens": getattr(usage, "total_tokens", 0), + } + prompt_details = getattr(usage, "prompt_tokens_details", None) + if prompt_details: + result["cached_prompt_tokens"] = ( + getattr(prompt_details, "cached_tokens", 0) or 0 + ) + completion_details = getattr(usage, "completion_tokens_details", None) + if completion_details: + result["reasoning_tokens"] = ( + getattr(completion_details, "reasoning_tokens", 0) or 0 + ) + return result + return {"total_tokens": 0} + + +def _first_tool_call_function(message: Any) -> tuple[str, dict[str, Any]] | None: + """Resolve the first function tool call using duck typing (Cerebras SDK types).""" + tool_calls = getattr(message, "tool_calls", None) or [] + if not tool_calls: + return None + tc = tool_calls[0] + fn = getattr(tc, "function", None) + if fn is None: + return None + name = getattr(fn, "name", None) or "" + if not name: + return None + raw_args = getattr(fn, "arguments", None) or "{}" + try: + args: dict[str, Any] = ( + json.loads(raw_args) if isinstance(raw_args, str) else dict(raw_args) + ) + except (json.JSONDecodeError, TypeError, ValueError): + args = {} + return name, args + + +class CerebrasCompletion(BaseLLM): + """Cerebras chat completions via ``cerebras-cloud-sdk``. + + Reads ``CEREBRAS_API_KEY`` and optional ``CEREBRAS_BASE_URL``. Only the chat + completions API is supported (``api`` is always ``\"completions\"``). 
+ """ + + llm_type: Literal["cerebras"] = "cerebras" + + timeout: float | None = None + max_retries: int = 2 + default_headers: dict[str, str] | None = None + default_query: dict[str, Any] | None = None + client_params: dict[str, Any] | None = None + top_p: float | None = None + frequency_penalty: float | None = None + presence_penalty: float | None = None + max_tokens: int | None = None + max_completion_tokens: int | None = None + seed: int | None = None + stream: bool = False + response_format: JsonResponseFormat | type[BaseModel] | None = None + logprobs: bool | None = None + top_logprobs: int | None = None + reasoning_effort: Literal["none", "low", "medium", "high"] | None = None + interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None + api: Literal["completions"] = "completions" + api_base: str | None = None + + service_tier: Literal["priority", "default", "auto", "flex"] | None = None + prompt_cache_key: str | None = None + clear_thinking: bool | None = None + + _client: Any = PrivateAttr(default=None) + _async_client: Any = PrivateAttr(default=None) + + @model_validator(mode="before") + @classmethod + def _normalize_cerebras_fields(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + data["provider"] = "cerebras" + data["api_key"] = data.get("api_key") or os.getenv(CEREBRAS_API_KEY_ENV) + if not data.get("base_url") and not data.get("api_base"): + env_base_url = os.getenv(CEREBRAS_BASE_URL_ENV) + if env_base_url: + data["base_url"] = env_base_url + if "api_base" not in data: + data["api_base"] = None + data["api"] = "completions" + return data + + @model_validator(mode="after") + def _init_clients(self) -> CerebrasCompletion: + try: + self._client = self._build_sync_client() + self._async_client = self._build_async_client() + except ValueError: + pass + return self + + def _get_client_params(self) -> dict[str, Any]: + if self.api_key is None: + self.api_key = os.getenv(CEREBRAS_API_KEY_ENV) + if self.api_key is None: + raise ValueError( + "CEREBRAS_API_KEY is required. Set it in the environment " + "or pass api_key= when constructing the LLM." 
+ ) + + base_url = self.base_url or self.api_base or os.getenv(CEREBRAS_BASE_URL_ENV) + + base_params: dict[str, Any] = { + "api_key": self.api_key, + "timeout": self.timeout, + "max_retries": self.max_retries, + "default_headers": self.default_headers, + "default_query": self.default_query, + } + if base_url: + base_params["base_url"] = base_url + + client_params = {k: v for k, v in base_params.items() if v is not None} + + if self.client_params: + client_params.update(self.client_params) + + return client_params + + def _build_sync_client(self) -> Any: + client_config = self._get_client_params() + if self.interceptor: + transport = HTTPTransport(interceptor=self.interceptor) + client_config["http_client"] = httpx.Client(transport=transport) + return Cerebras(**client_config) + + def _build_async_client(self) -> Any: + client_config = self._get_client_params() + if self.interceptor: + transport = AsyncHTTPTransport(interceptor=self.interceptor) + client_config["http_client"] = httpx.AsyncClient(transport=transport) + return AsyncCerebras(**client_config) + + def _get_sync_client(self) -> Any: + if self._client is None: + self._client = self._build_sync_client() + return self._client + + def _get_async_client(self) -> Any: + if self._async_client is None: + self._async_client = self._build_async_client() + return self._async_client + + def to_config_dict(self) -> dict[str, Any]: + config = super().to_config_dict() + if self.timeout is not None: + config["timeout"] = self.timeout + if self.max_retries != 2: + config["max_retries"] = self.max_retries + if self.top_p is not None: + config["top_p"] = self.top_p + if self.frequency_penalty is not None: + config["frequency_penalty"] = self.frequency_penalty + if self.presence_penalty is not None: + config["presence_penalty"] = self.presence_penalty + if self.max_tokens is not None: + config["max_tokens"] = self.max_tokens + if self.max_completion_tokens is not None: + config["max_completion_tokens"] = self.max_completion_tokens + if self.seed is not None: + config["seed"] = self.seed + if self.reasoning_effort is not None: + config["reasoning_effort"] = self.reasoning_effort + if self.stream: + config["stream"] = True + if self.service_tier is not None: + config["service_tier"] = self.service_tier + if self.prompt_cache_key is not None: + config["prompt_cache_key"] = self.prompt_cache_key + if self.clear_thinking is not None: + config["clear_thinking"] = self.clear_thinking + return config + + def _convert_tools_for_interference( + self, tools: list[dict[str, BaseTool]] + ) -> list[dict[str, Any]]: + openai_tools: list[dict[str, Any]] = [] + + for tool in tools: + name, description, parameters = safe_tool_conversion(tool, "OpenAI") + + openai_tool: dict[str, Any] = { + "type": "function", + "function": { + "name": name, + "description": description, + "strict": True, + }, + } + + if parameters: + params_dict = ( + parameters if isinstance(parameters, dict) else dict(parameters) + ) + openai_tool["function"]["parameters"] = ( + sanitize_tool_params_for_openai_strict(params_dict) + ) + + openai_tools.append(openai_tool) + return openai_tools + + def _prepare_completion_params( + self, + messages: list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + ) -> dict[str, Any]: + params: dict[str, Any] = { + "model": self.model, + "messages": messages, + } + if self.stream: + params["stream"] = self.stream + params["stream_options"] = {"include_usage": True} + + params.update(self.additional_params) + + if self.temperature is not None: + 
params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p + if self.frequency_penalty is not None: + params["frequency_penalty"] = self.frequency_penalty + if self.presence_penalty is not None: + params["presence_penalty"] = self.presence_penalty + if self.max_completion_tokens is not None: + params["max_completion_tokens"] = self.max_completion_tokens + elif self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + if self.seed is not None: + params["seed"] = self.seed + if self.logprobs is not None: + params["logprobs"] = self.logprobs + if self.top_logprobs is not None: + params["top_logprobs"] = self.top_logprobs + + if self.response_format is not None: + if isinstance(self.response_format, type) and issubclass( + self.response_format, BaseModel + ): + params["response_format"] = generate_model_description( + self.response_format + ) + elif isinstance(self.response_format, dict): + params["response_format"] = self.response_format + + if tools: + params["tools"] = self._convert_tools_for_interference(tools) + params["tool_choice"] = "auto" + + if self.reasoning_effort is not None: + params["reasoning_effort"] = self.reasoning_effort + if self.service_tier is not None: + params["service_tier"] = self.service_tier + if self.prompt_cache_key is not None: + params["prompt_cache_key"] = self.prompt_cache_key + if self.clear_thinking is not None: + params["clear_thinking"] = self.clear_thinking + + crewai_specific_params = { + "callbacks", + "available_functions", + "from_task", + "from_agent", + "provider", + "api_key", + "base_url", + "api_base", + "timeout", + } + + return {k: v for k, v in params.items() if k not in crewai_specific_params} + + def _finalize_streaming_response( + self, + full_response: str, + tool_calls: dict[int, dict[str, Any]], + usage_data: dict[str, Any] | None, + params: dict[str, Any], + available_functions: dict[str, Any] | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, + ) -> str | list[dict[str, Any]]: + if usage_data: + self._track_token_usage_internal(usage_data) + + if tool_calls and not available_functions: + tool_calls_list = [ + { + "id": call_data["id"], + "type": "function", + "function": { + "name": call_data["name"], + "arguments": call_data["arguments"], + }, + "index": call_data["index"], + } + for call_data in tool_calls.values() + ] + self._emit_call_completed_event( + response=tool_calls_list, + call_type=LLMCallType.TOOL_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage_data, + ) + return tool_calls_list + + if tool_calls and available_functions: + for call_data in tool_calls.values(): + function_name = call_data["name"] + arguments = call_data["arguments"] + + if not function_name or not arguments: + continue + + if function_name not in available_functions: + logging.warning( + f"Function '{function_name}' not found in available functions" + ) + continue + + try: + function_args = json.loads(arguments) + except json.JSONDecodeError as e: + logging.error(f"Failed to parse streamed tool arguments: {e}") + continue + + result = self._handle_tool_execution( + function_name=function_name, + function_args=function_args, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + if result is not None: + return result + + full_response = self._apply_stop_words(full_response) + + self._emit_call_completed_event( + response=full_response, + call_type=LLMCallType.LLM_CALL, + 
from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage_data, + ) + + return full_response + + def _handle_completion( + self, + params: dict[str, Any], + available_functions: dict[str, Any] | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | Any: + try: + if response_model: + structured_params = { + k: v for k, v in params.items() if k != "response_format" + } + structured_params["response_format"] = generate_model_description( + response_model + ) + parsed_response = self._get_sync_client().chat.completions.create( + **structured_params + ) + usage = _extract_chat_usage(parsed_response) + self._track_token_usage_internal(usage) + message = parsed_response.choices[0].message + content = getattr(message, "content", None) or "" + structured = self._validate_structured_output(content, response_model) + self._emit_call_completed_event( + response=structured.model_dump_json() + if isinstance(structured, BaseModel) + else structured, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage, + ) + return structured + + response = self._get_sync_client().chat.completions.create(**params) + + usage = _extract_chat_usage(response) + self._track_token_usage_internal(usage) + + message = response.choices[0].message + + if message.tool_calls and not available_functions: + self._emit_call_completed_event( + response=list(message.tool_calls), + call_type=LLMCallType.TOOL_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage, + ) + return list(message.tool_calls) + + if message.tool_calls and available_functions: + parsed_tool = _first_tool_call_function(message) + if not parsed_tool: + return message.content or "" + function_name, function_args = parsed_tool + + result = self._handle_tool_execution( + function_name=function_name, + function_args=function_args, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + if result is not None: + return result + + content = message.content or "" + + if self.response_format and isinstance(self.response_format, type): + try: + structured_result = self._validate_structured_output( + content, self.response_format + ) + self._emit_call_completed_event( + response=structured_result, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage, + ) + return structured_result + except ValueError as e: + logging.warning(f"Structured output validation failed: {e}") + + content = self._apply_stop_words(content) + + self._emit_call_completed_event( + response=content, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage, + ) + + if usage.get("total_tokens", 0) > 0: + logging.info(f"Cerebras API usage: {usage}") + + content = self._invoke_after_llm_call_hooks( + params["messages"], content, from_agent + ) + except NotFoundError as e: + error_msg = f"Model {self.model} not found: {e}" + logging.error(error_msg) + self._emit_call_failed_event( + error=error_msg, from_task=from_task, from_agent=from_agent + ) + raise ValueError(error_msg) from e + except APIConnectionError as e: + error_msg = f"Failed to connect to Cerebras API: {e}" + logging.error(error_msg) + self._emit_call_failed_event( + error=error_msg, from_task=from_task, from_agent=from_agent + ) + 
raise ConnectionError(error_msg) from e + except Exception as e: + if is_context_length_exceeded(e): + logging.error(f"Context window exceeded: {e}") + raise LLMContextLengthExceededError(str(e)) from e + + error_msg = f"Cerebras API call failed: {e!s}" + logging.error(error_msg) + self._emit_call_failed_event( + error=error_msg, from_task=from_task, from_agent=from_agent + ) + raise e from e + + return content + + def _handle_streaming_completion( + self, + params: dict[str, Any], + available_functions: dict[str, Any] | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | list[dict[str, Any]] | BaseModel: + full_response = "" + tool_calls: dict[int, dict[str, Any]] = {} + usage_data: dict[str, Any] | None = None + + if response_model: + completion_stream = self._get_sync_client().chat.completions.create( + **params + ) + + accumulated_content = "" + for chunk in completion_stream: + response_id_stream = chunk.id if hasattr(chunk, "id") else None + + if hasattr(chunk, "usage") and chunk.usage: + usage_data = _extract_chat_usage(chunk) + continue + + if not chunk.choices: + continue + + choice = chunk.choices[0] + delta = choice.delta + + if delta.content: + accumulated_content += delta.content + self._emit_stream_chunk_event( + chunk=delta.content, + from_task=from_task, + from_agent=from_agent, + response_id=response_id_stream, + ) + + if usage_data: + self._track_token_usage_internal(usage_data) + + try: + parsed_object = response_model.model_validate_json(accumulated_content) + + self._emit_call_completed_event( + response=parsed_object.model_dump_json(), + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage_data, + ) + + return parsed_object + except Exception as e: + logging.error(f"Failed to parse structured output from stream: {e}") + self._emit_call_completed_event( + response=accumulated_content, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage_data, + ) + return accumulated_content + + completion_stream = self._get_sync_client().chat.completions.create(**params) + + for completion_chunk in completion_stream: + response_id_stream = ( + completion_chunk.id if hasattr(completion_chunk, "id") else None + ) + + if hasattr(completion_chunk, "usage") and completion_chunk.usage: + usage_data = _extract_chat_usage(completion_chunk) + continue + + if not completion_chunk.choices: + continue + + choice = completion_chunk.choices[0] + chunk_delta = choice.delta + + if chunk_delta.content: + full_response += chunk_delta.content + self._emit_stream_chunk_event( + chunk=chunk_delta.content, + from_task=from_task, + from_agent=from_agent, + response_id=response_id_stream, + ) + + if chunk_delta.tool_calls: + for tool_call in chunk_delta.tool_calls: + tool_index = tool_call.index if tool_call.index is not None else 0 + if tool_index not in tool_calls: + tool_calls[tool_index] = { + "id": tool_call.id, + "name": "", + "arguments": "", + "index": tool_index, + } + elif tool_call.id and not tool_calls[tool_index]["id"]: + tool_calls[tool_index]["id"] = tool_call.id + + if tool_call.function and tool_call.function.name: + tool_calls[tool_index]["name"] = tool_call.function.name + if tool_call.function and tool_call.function.arguments: + tool_calls[tool_index]["arguments"] += ( + tool_call.function.arguments + ) + + self._emit_stream_chunk_event( + 
chunk=tool_call.function.arguments + if tool_call.function and tool_call.function.arguments + else "", + from_task=from_task, + from_agent=from_agent, + tool_call={ + "id": tool_calls[tool_index]["id"], + "function": { + "name": tool_calls[tool_index]["name"], + "arguments": tool_calls[tool_index]["arguments"], + }, + "type": "function", + "index": tool_calls[tool_index]["index"], + }, + call_type=LLMCallType.TOOL_CALL, + response_id=response_id_stream, + ) + + result = self._finalize_streaming_response( + full_response=full_response, + tool_calls=tool_calls, + usage_data=usage_data, + params=params, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + if isinstance(result, str): + return self._invoke_after_llm_call_hooks( + params["messages"], result, from_agent + ) + return result + + async def _ahandle_completion( + self, + params: dict[str, Any], + available_functions: dict[str, Any] | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | Any: + try: + if response_model: + structured_params = { + k: v for k, v in params.items() if k != "response_format" + } + structured_params["response_format"] = generate_model_description( + response_model + ) + parsed_response = ( + await self._get_async_client().chat.completions.create( + **structured_params + ) + ) + usage = _extract_chat_usage(parsed_response) + self._track_token_usage_internal(usage) + message = parsed_response.choices[0].message + content = getattr(message, "content", None) or "" + structured = self._validate_structured_output(content, response_model) + self._emit_call_completed_event( + response=structured.model_dump_json() + if isinstance(structured, BaseModel) + else structured, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage, + ) + return structured + + response = await self._get_async_client().chat.completions.create(**params) + + usage = _extract_chat_usage(response) + self._track_token_usage_internal(usage) + + message = response.choices[0].message + + if message.tool_calls and not available_functions: + self._emit_call_completed_event( + response=list(message.tool_calls), + call_type=LLMCallType.TOOL_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage, + ) + return list(message.tool_calls) + + if message.tool_calls and available_functions: + parsed_tool = _first_tool_call_function(message) + if not parsed_tool: + return message.content or "" + function_name, function_args = parsed_tool + + result = self._handle_tool_execution( + function_name=function_name, + function_args=function_args, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + if result is not None: + return result + + content = message.content or "" + + if self.response_format and isinstance(self.response_format, type): + try: + structured_result = self._validate_structured_output( + content, self.response_format + ) + self._emit_call_completed_event( + response=structured_result, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage, + ) + return structured_result + except ValueError as e: + logging.warning(f"Structured output validation failed: {e}") + + content = self._apply_stop_words(content) + + self._emit_call_completed_event( + response=content, + call_type=LLMCallType.LLM_CALL, + 
from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage, + ) + + if usage.get("total_tokens", 0) > 0: + logging.info(f"Cerebras API usage: {usage}") + except NotFoundError as e: + error_msg = f"Model {self.model} not found: {e}" + logging.error(error_msg) + self._emit_call_failed_event( + error=error_msg, from_task=from_task, from_agent=from_agent + ) + raise ValueError(error_msg) from e + except APIConnectionError as e: + error_msg = f"Failed to connect to Cerebras API: {e}" + logging.error(error_msg) + self._emit_call_failed_event( + error=error_msg, from_task=from_task, from_agent=from_agent + ) + raise ConnectionError(error_msg) from e + except Exception as e: + if is_context_length_exceeded(e): + logging.error(f"Context window exceeded: {e}") + raise LLMContextLengthExceededError(str(e)) from e + + error_msg = f"Cerebras API call failed: {e!s}" + logging.error(error_msg) + self._emit_call_failed_event( + error=error_msg, from_task=from_task, from_agent=from_agent + ) + raise e from e + + return content + + async def _ahandle_streaming_completion( + self, + params: dict[str, Any], + available_functions: dict[str, Any] | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | list[dict[str, Any]] | BaseModel: + full_response = "" + tool_calls: dict[int, dict[str, Any]] = {} + usage_data: dict[str, Any] | None = None + + if response_model: + completion_stream: AsyncIterator[ + Any + ] = await self._get_async_client().chat.completions.create(**params) + + accumulated_content = "" + async for chunk in completion_stream: + response_id_stream = chunk.id if hasattr(chunk, "id") else None + + if hasattr(chunk, "usage") and chunk.usage: + usage_data = _extract_chat_usage(chunk) + continue + + if not chunk.choices: + continue + + choice = chunk.choices[0] + delta = choice.delta + + if delta.content: + accumulated_content += delta.content + self._emit_stream_chunk_event( + chunk=delta.content, + from_task=from_task, + from_agent=from_agent, + response_id=response_id_stream, + ) + + if usage_data: + self._track_token_usage_internal(usage_data) + + try: + parsed_object = response_model.model_validate_json(accumulated_content) + + self._emit_call_completed_event( + response=parsed_object.model_dump_json(), + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage_data, + ) + + return parsed_object + except Exception as e: + logging.error(f"Failed to parse structured output from stream: {e}") + self._emit_call_completed_event( + response=accumulated_content, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=params["messages"], + usage=usage_data, + ) + return accumulated_content + + stream: AsyncIterator[ + Any + ] = await self._get_async_client().chat.completions.create(**params) + + async for chunk in stream: + response_id_stream = chunk.id if hasattr(chunk, "id") else None + + if hasattr(chunk, "usage") and chunk.usage: + usage_data = _extract_chat_usage(chunk) + continue + + if not chunk.choices: + continue + + choice = chunk.choices[0] + chunk_delta = choice.delta + + if chunk_delta.content: + full_response += chunk_delta.content + self._emit_stream_chunk_event( + chunk=chunk_delta.content, + from_task=from_task, + from_agent=from_agent, + response_id=response_id_stream, + ) + + if chunk_delta.tool_calls: + for tool_call in chunk_delta.tool_calls: + tool_index = 
tool_call.index if tool_call.index is not None else 0 + if tool_index not in tool_calls: + tool_calls[tool_index] = { + "id": tool_call.id, + "name": "", + "arguments": "", + "index": tool_index, + } + elif tool_call.id and not tool_calls[tool_index]["id"]: + tool_calls[tool_index]["id"] = tool_call.id + + if tool_call.function and tool_call.function.name: + tool_calls[tool_index]["name"] = tool_call.function.name + if tool_call.function and tool_call.function.arguments: + tool_calls[tool_index]["arguments"] += ( + tool_call.function.arguments + ) + + self._emit_stream_chunk_event( + chunk=tool_call.function.arguments + if tool_call.function and tool_call.function.arguments + else "", + from_task=from_task, + from_agent=from_agent, + tool_call={ + "id": tool_calls[tool_index]["id"], + "function": { + "name": tool_calls[tool_index]["name"], + "arguments": tool_calls[tool_index]["arguments"], + }, + "type": "function", + "index": tool_calls[tool_index]["index"], + }, + call_type=LLMCallType.TOOL_CALL, + response_id=response_id_stream, + ) + + result = self._finalize_streaming_response( + full_response=full_response, + tool_calls=tool_calls, + usage_data=usage_data, + params=params, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + if isinstance(result, str): + return self._invoke_after_llm_call_hooks( + params["messages"], result, from_agent + ) + return result + + def _call_completions( + self, + messages: list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: BaseAgent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | Any: + completion_params = self._prepare_completion_params( + messages=messages, tools=tools + ) + + if self.stream: + return self._handle_streaming_completion( + params=completion_params, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + + return self._handle_completion( + params=completion_params, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + + async def _acall_completions( + self, + messages: list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: BaseAgent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | Any: + completion_params = self._prepare_completion_params( + messages=messages, tools=tools + ) + + if self.stream: + return await self._ahandle_streaming_completion( + params=completion_params, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + + return await self._ahandle_completion( + params=completion_params, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + + def call( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: BaseAgent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | Any: + with llm_call_context(): + try: + self._emit_call_started_event( + messages=messages, + tools=tools, + callbacks=callbacks, + 
available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + formatted_messages = self._format_messages(messages) + + if not self._invoke_before_llm_call_hooks( + formatted_messages, from_agent + ): + raise ValueError("LLM call blocked by before_llm_call hook") + + return self._call_completions( + messages=formatted_messages, + tools=tools, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + + except Exception as e: + error_msg = f"Cerebras API call failed: {e!s}" + logging.error(error_msg) + self._emit_call_failed_event( + error=error_msg, from_task=from_task, from_agent=from_agent + ) + raise + + async def acall( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: BaseAgent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | Any: + with llm_call_context(): + try: + self._emit_call_started_event( + messages=messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + formatted_messages = self._format_messages(messages) + + if not self._invoke_before_llm_call_hooks( + formatted_messages, from_agent + ): + raise ValueError("LLM call blocked by before_llm_call hook") + + return await self._acall_completions( + messages=formatted_messages, + tools=tools, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + + except Exception as e: + error_msg = f"Cerebras API call failed: {e!s}" + logging.error(error_msg) + self._emit_call_failed_event( + error=error_msg, from_task=from_task, from_agent=from_agent + ) + raise + + def supports_function_calling(self) -> bool: + return True + + def get_context_window_size(self) -> int: + from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO + + return int(8192 * CONTEXT_WINDOW_USAGE_RATIO) diff --git a/lib/crewai/tests/cassettes/llms/cerebras/test_cerebras_basic_completion.yaml b/lib/crewai/tests/cassettes/llms/cerebras/test_cerebras_basic_completion.yaml new file mode 100644 index 000000000..2e662d736 --- /dev/null +++ b/lib/crewai/tests/cassettes/llms/cerebras/test_cerebras_basic_completion.yaml @@ -0,0 +1,233 @@ +interactions: +- request: + body: '' + headers: + accept: + - application/json + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + host: + - api.cerebras.ai + user-agent: + - X-USER-AGENT-XXX + x-stainless-arch: + - X-STAINLESS-ARCH-XXX + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - X-STAINLESS-OS-XXX + x-stainless-package-version: + - 1.67.0 + x-stainless-read-timeout: + - X-STAINLESS-READ-TIMEOUT-XXX + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.3 + method: GET + uri: https://api.cerebras.ai/v1/tcp_warming + response: + body: + string: 'This request is being sent by the Cerebras Cloud SDK to warm up your + TCP connection so that your requests will have lower TTFT. + + If you don''t want this, please set `"warmTCPConnection": false` (NodeJS) + or `warm_tcp_connection=False` (Python) in the SDK constructor. 
+ + + For more assistance, contact us at support@cerebras.ai + + ' + headers: + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + cf-ray: + - CF-RAY-XXX + content-type: + - text/plain; charset=utf-8 + date: + - Wed, 06 May 2026 22:33:47 GMT + referrer-policy: + - REFERRER-POLICY-XXX + server: + - cloudflare + set-cookie: + - SET-COOKIE-XXX + strict-transport-security: + - STS-XXX + x-content-type-options: + - X-CONTENT-TYPE-XXX + status: + code: 200 + message: OK +- request: + body: '' + headers: + accept: + - application/json + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + host: + - api.cerebras.ai + user-agent: + - X-USER-AGENT-XXX + x-stainless-arch: + - X-STAINLESS-ARCH-XXX + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - X-STAINLESS-OS-XXX + x-stainless-package-version: + - 1.67.0 + x-stainless-read-timeout: + - X-STAINLESS-READ-TIMEOUT-XXX + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.3 + method: GET + uri: https://api.cerebras.ai/v1/tcp_warming + response: + body: + string: 'This request is being sent by the Cerebras Cloud SDK to warm up your + TCP connection so that your requests will have lower TTFT. + + If you don''t want this, please set `"warmTCPConnection": false` (NodeJS) + or `warm_tcp_connection=False` (Python) in the SDK constructor. + + + For more assistance, contact us at support@cerebras.ai + + ' + headers: + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + cf-ray: + - CF-RAY-XXX + content-type: + - text/plain; charset=utf-8 + date: + - Wed, 06 May 2026 22:33:47 GMT + referrer-policy: + - REFERRER-POLICY-XXX + server: + - cloudflare + set-cookie: + - SET-COOKIE-XXX + strict-transport-security: + - STS-XXX + x-content-type-options: + - X-CONTENT-TYPE-XXX + status: + code: 200 + message: OK +- request: + body: '{"model":"llama3.1-8b","max_completion_tokens":32,"messages":[{"role":"user","content":"Reply + with exactly the word: OK"}]}' + headers: + accept: + - application/json + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + content-length: + - '123' + content-type: + - application/json + cookie: + - COOKIE-XXX + host: + - api.cerebras.ai + user-agent: + - X-USER-AGENT-XXX + x-stainless-arch: + - X-STAINLESS-ARCH-XXX + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - X-STAINLESS-OS-XXX + x-stainless-package-version: + - 1.67.0 + x-stainless-read-timeout: + - X-STAINLESS-READ-TIMEOUT-XXX + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.3 + method: POST + uri: https://api.cerebras.ai/v1/chat/completions + response: + body: + string: '{"id":"chatcmpl-63f7fec0-7ac5-4dec-adca-427a39ca8d77","choices":[{"finish_reason":"stop","index":0,"message":{"content":"OK","role":"assistant"}}],"created":1778106827,"model":"llama3.1-8b","system_fingerprint":"fp_96e7e4453bc38316a23a","object":"chat.completion","usage":{"total_tokens":44,"completion_tokens":2,"completion_tokens_details":{"accepted_prediction_tokens":0,"rejected_prediction_tokens":0,"reasoning_tokens":0},"prompt_tokens":42,"prompt_tokens_details":{"cached_tokens":0}},"time_info":{"created":1778106827.7362924,"queue_time":2.178798466,"prompt_time":0.003727911,"completion_time":0.00089246,"total_time":2.1848058700561523}}' + headers: + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + cf-ray: + - CF-RAY-XXX + content-type: + - 
application/json + date: + - Wed, 06 May 2026 22:33:49 GMT + inference-id: + - chatcmpl-63f7fec0-7ac5-4dec-adca-427a39ca8d77 + referrer-policy: + - REFERRER-POLICY-XXX + server: + - cloudflare + strict-transport-security: + - STS-XXX + x-content-type-options: + - X-CONTENT-TYPE-XXX + x-ratelimit-limit-requests-day: + - '14400' + x-ratelimit-limit-requests-hour: + - '900' + x-ratelimit-limit-requests-minute: + - '30' + x-ratelimit-limit-tokens-day: + - '1000000' + x-ratelimit-limit-tokens-hour: + - '1000000' + x-ratelimit-limit-tokens-minute: + - '60000' + x-ratelimit-remaining-requests-day: + - '14399' + x-ratelimit-remaining-requests-hour: + - '899' + x-ratelimit-remaining-requests-minute: + - '29' + x-ratelimit-remaining-tokens-day: + - '999961' + x-ratelimit-remaining-tokens-hour: + - '999961' + x-ratelimit-remaining-tokens-minute: + - '59961' + x-request-id: + - X-REQUEST-ID-XXX + status: + code: 200 + message: OK +version: 1 diff --git a/lib/crewai/tests/cassettes/llms/cerebras/test_cerebras_streaming_completion.yaml b/lib/crewai/tests/cassettes/llms/cerebras/test_cerebras_streaming_completion.yaml new file mode 100644 index 000000000..1e2a765be --- /dev/null +++ b/lib/crewai/tests/cassettes/llms/cerebras/test_cerebras_streaming_completion.yaml @@ -0,0 +1,249 @@ +interactions: +- request: + body: '' + headers: + accept: + - application/json + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + host: + - api.cerebras.ai + user-agent: + - X-USER-AGENT-XXX + x-stainless-arch: + - X-STAINLESS-ARCH-XXX + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - X-STAINLESS-OS-XXX + x-stainless-package-version: + - 1.67.0 + x-stainless-read-timeout: + - X-STAINLESS-READ-TIMEOUT-XXX + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.3 + method: GET + uri: https://api.cerebras.ai/v1/tcp_warming + response: + body: + string: 'This request is being sent by the Cerebras Cloud SDK to warm up your + TCP connection so that your requests will have lower TTFT. + + If you don''t want this, please set `"warmTCPConnection": false` (NodeJS) + or `warm_tcp_connection=False` (Python) in the SDK constructor. + + + For more assistance, contact us at support@cerebras.ai + + ' + headers: + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + cf-ray: + - CF-RAY-XXX + content-type: + - text/plain; charset=utf-8 + date: + - Wed, 06 May 2026 22:33:50 GMT + referrer-policy: + - REFERRER-POLICY-XXX + server: + - cloudflare + set-cookie: + - SET-COOKIE-XXX + strict-transport-security: + - STS-XXX + x-content-type-options: + - X-CONTENT-TYPE-XXX + status: + code: 200 + message: OK +- request: + body: '' + headers: + accept: + - application/json + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + host: + - api.cerebras.ai + user-agent: + - X-USER-AGENT-XXX + x-stainless-arch: + - X-STAINLESS-ARCH-XXX + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - X-STAINLESS-OS-XXX + x-stainless-package-version: + - 1.67.0 + x-stainless-read-timeout: + - X-STAINLESS-READ-TIMEOUT-XXX + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.3 + method: GET + uri: https://api.cerebras.ai/v1/tcp_warming + response: + body: + string: 'This request is being sent by the Cerebras Cloud SDK to warm up your + TCP connection so that your requests will have lower TTFT. 
+ + If you don''t want this, please set `"warmTCPConnection": false` (NodeJS) + or `warm_tcp_connection=False` (Python) in the SDK constructor. + + + For more assistance, contact us at support@cerebras.ai + + ' + headers: + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + cf-ray: + - CF-RAY-XXX + content-type: + - text/plain; charset=utf-8 + date: + - Wed, 06 May 2026 22:33:50 GMT + referrer-policy: + - REFERRER-POLICY-XXX + server: + - cloudflare + set-cookie: + - SET-COOKIE-XXX + strict-transport-security: + - STS-XXX + x-content-type-options: + - X-CONTENT-TYPE-XXX + status: + code: 200 + message: OK +- request: + body: '{"model":"llama3.1-8b","max_completion_tokens":32,"messages":[{"role":"user","content":"Count: + one, two, three."}],"stream":true,"stream_options":{"include_usage":true}}' + headers: + accept: + - application/json + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + content-length: + - '169' + content-type: + - application/json + cookie: + - COOKIE-XXX + host: + - api.cerebras.ai + user-agent: + - X-USER-AGENT-XXX + x-stainless-arch: + - X-STAINLESS-ARCH-XXX + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - X-STAINLESS-OS-XXX + x-stainless-package-version: + - 1.67.0 + x-stainless-read-timeout: + - X-STAINLESS-READ-TIMEOUT-XXX + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.3 + method: POST + uri: https://api.cerebras.ai/v1/chat/completions + response: + body: + string: 'data: {"id":"chatcmpl-e4b66146-cac9-459a-9909-ddd5fb56b3fd","choices":[{"delta":{"role":"assistant"},"index":0}],"created":1778106830,"model":"llama3.1-8b","system_fingerprint":"fp_96e7e4453bc38316a23a","object":"chat.completion.chunk"} + + + data: {"id":"chatcmpl-e4b66146-cac9-459a-9909-ddd5fb56b3fd","choices":[{"delta":{"content":"One"},"index":0}],"created":1778106830,"model":"llama3.1-8b","system_fingerprint":"fp_96e7e4453bc38316a23a","object":"chat.completion.chunk"} + + + data: {"id":"chatcmpl-e4b66146-cac9-459a-9909-ddd5fb56b3fd","choices":[{"delta":{"content":", + two, three. 
What''s next?"},"index":0}],"created":1778106830,"model":"llama3.1-8b","system_fingerprint":"fp_96e7e4453bc38316a23a","object":"chat.completion.chunk"} + + + data: {"id":"chatcmpl-e4b66146-cac9-459a-9909-ddd5fb56b3fd","choices":[{"delta":{},"finish_reason":"stop","index":0}],"created":1778106830,"model":"llama3.1-8b","system_fingerprint":"fp_96e7e4453bc38316a23a","object":"chat.completion.chunk","usage":{"total_tokens":54,"completion_tokens":11,"completion_tokens_details":{"accepted_prediction_tokens":0,"rejected_prediction_tokens":0,"reasoning_tokens":0},"prompt_tokens":43,"prompt_tokens_details":{"cached_tokens":0}},"time_info":{"created":1778106830.7255669,"queue_time":0.00023963,"prompt_time":0.002608115,"completion_time":0.00474267,"total_time":0.009609460830688477}} + + + data: [DONE] + + + ' + headers: + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + cf-ray: + - CF-RAY-XXX + content-type: + - text/event-stream; charset=utf-8 + date: + - Wed, 06 May 2026 22:33:50 GMT + inference-id: + - chatcmpl-e4b66146-cac9-459a-9909-ddd5fb56b3fd + referrer-policy: + - REFERRER-POLICY-XXX + server: + - cloudflare + strict-transport-security: + - STS-XXX + x-content-type-options: + - X-CONTENT-TYPE-XXX + x-ratelimit-limit-requests-day: + - '14400' + x-ratelimit-limit-requests-hour: + - '900' + x-ratelimit-limit-requests-minute: + - '30' + x-ratelimit-limit-tokens-day: + - '1000000' + x-ratelimit-limit-tokens-hour: + - '1000000' + x-ratelimit-limit-tokens-minute: + - '60000' + x-ratelimit-remaining-requests-day: + - '14398' + x-ratelimit-remaining-requests-hour: + - '898' + x-ratelimit-remaining-requests-minute: + - '29' + x-ratelimit-remaining-tokens-day: + - '999953' + x-ratelimit-remaining-tokens-hour: + - '999963' + x-ratelimit-remaining-tokens-minute: + - '59963' + x-request-id: + - X-REQUEST-ID-XXX + status: + code: 200 + message: OK +version: 1 diff --git a/lib/crewai/tests/cassettes/llms/cerebras/test_cerebras_temperature_and_seed_passed_to_sdk.yaml b/lib/crewai/tests/cassettes/llms/cerebras/test_cerebras_temperature_and_seed_passed_to_sdk.yaml new file mode 100644 index 000000000..c513a5b6c --- /dev/null +++ b/lib/crewai/tests/cassettes/llms/cerebras/test_cerebras_temperature_and_seed_passed_to_sdk.yaml @@ -0,0 +1,234 @@ +interactions: +- request: + body: '' + headers: + accept: + - application/json + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + host: + - api.cerebras.ai + user-agent: + - X-USER-AGENT-XXX + x-stainless-arch: + - X-STAINLESS-ARCH-XXX + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - X-STAINLESS-OS-XXX + x-stainless-package-version: + - 1.67.0 + x-stainless-read-timeout: + - X-STAINLESS-READ-TIMEOUT-XXX + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.3 + method: GET + uri: https://api.cerebras.ai/v1/tcp_warming + response: + body: + string: 'This request is being sent by the Cerebras Cloud SDK to warm up your + TCP connection so that your requests will have lower TTFT. + + If you don''t want this, please set `"warmTCPConnection": false` (NodeJS) + or `warm_tcp_connection=False` (Python) in the SDK constructor. 
+ + + For more assistance, contact us at support@cerebras.ai + + ' + headers: + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + cf-ray: + - CF-RAY-XXX + content-type: + - text/plain; charset=utf-8 + date: + - Wed, 06 May 2026 22:33:51 GMT + referrer-policy: + - REFERRER-POLICY-XXX + server: + - cloudflare + set-cookie: + - SET-COOKIE-XXX + strict-transport-security: + - STS-XXX + x-content-type-options: + - X-CONTENT-TYPE-XXX + status: + code: 200 + message: OK +- request: + body: '' + headers: + accept: + - application/json + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + host: + - api.cerebras.ai + user-agent: + - X-USER-AGENT-XXX + x-stainless-arch: + - X-STAINLESS-ARCH-XXX + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - X-STAINLESS-OS-XXX + x-stainless-package-version: + - 1.67.0 + x-stainless-read-timeout: + - X-STAINLESS-READ-TIMEOUT-XXX + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.3 + method: GET + uri: https://api.cerebras.ai/v1/tcp_warming + response: + body: + string: 'This request is being sent by the Cerebras Cloud SDK to warm up your + TCP connection so that your requests will have lower TTFT. + + If you don''t want this, please set `"warmTCPConnection": false` (NodeJS) + or `warm_tcp_connection=False` (Python) in the SDK constructor. + + + For more assistance, contact us at support@cerebras.ai + + ' + headers: + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + cf-ray: + - CF-RAY-XXX + content-type: + - text/plain; charset=utf-8 + date: + - Wed, 06 May 2026 22:33:51 GMT + referrer-policy: + - REFERRER-POLICY-XXX + server: + - cloudflare + set-cookie: + - SET-COOKIE-XXX + strict-transport-security: + - STS-XXX + x-content-type-options: + - X-CONTENT-TYPE-XXX + status: + code: 200 + message: OK +- request: + body: '{"model":"llama3.1-8b","max_completion_tokens":64,"messages":[{"role":"user","content":"Say + hi."}],"seed":7,"temperature":0.0}' + headers: + accept: + - application/json + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + content-length: + - '126' + content-type: + - application/json + cookie: + - COOKIE-XXX + host: + - api.cerebras.ai + user-agent: + - X-USER-AGENT-XXX + x-stainless-arch: + - X-STAINLESS-ARCH-XXX + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - X-STAINLESS-OS-XXX + x-stainless-package-version: + - 1.67.0 + x-stainless-read-timeout: + - X-STAINLESS-READ-TIMEOUT-XXX + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.3 + method: POST + uri: https://api.cerebras.ai/v1/chat/completions + response: + body: + string: '{"id":"chatcmpl-5bf61f2f-c4a5-4122-b921-d8cba7d8ebc0","choices":[{"finish_reason":"stop","index":0,"message":{"content":"Hello. 
+ How can I assist you today?","role":"assistant"}}],"created":1778106831,"model":"llama3.1-8b","system_fingerprint":"fp_96e7e4453bc38316a23a","object":"chat.completion","usage":{"total_tokens":48,"completion_tokens":10,"completion_tokens_details":{"accepted_prediction_tokens":0,"rejected_prediction_tokens":0,"reasoning_tokens":0},"prompt_tokens":38,"prompt_tokens_details":{"cached_tokens":0}},"time_info":{"created":1778106831.5465255,"queue_time":7.401e-05,"prompt_time":0.001719314,"completion_time":0.003982367,"total_time":0.0068645477294921875}}' + headers: + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + cf-ray: + - CF-RAY-XXX + content-type: + - application/json + date: + - Wed, 06 May 2026 22:33:51 GMT + inference-id: + - chatcmpl-5bf61f2f-c4a5-4122-b921-d8cba7d8ebc0 + referrer-policy: + - REFERRER-POLICY-XXX + server: + - cloudflare + strict-transport-security: + - STS-XXX + x-content-type-options: + - X-CONTENT-TYPE-XXX + x-ratelimit-limit-requests-day: + - '14400' + x-ratelimit-limit-requests-hour: + - '900' + x-ratelimit-limit-requests-minute: + - '30' + x-ratelimit-limit-tokens-day: + - '1000000' + x-ratelimit-limit-tokens-hour: + - '1000000' + x-ratelimit-limit-tokens-minute: + - '60000' + x-ratelimit-remaining-requests-day: + - '14397' + x-ratelimit-remaining-requests-hour: + - '898' + x-ratelimit-remaining-requests-minute: + - '28' + x-ratelimit-remaining-tokens-day: + - '999871' + x-ratelimit-remaining-tokens-hour: + - '999881' + x-ratelimit-remaining-tokens-minute: + - '59881' + x-request-id: + - X-REQUEST-ID-XXX + status: + code: 200 + message: OK +version: 1 diff --git a/lib/crewai/tests/llms/cerebras/__init__.py b/lib/crewai/tests/llms/cerebras/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/crewai/tests/llms/cerebras/test_cerebras_completion.py b/lib/crewai/tests/llms/cerebras/test_cerebras_completion.py new file mode 100644 index 000000000..ed3039710 --- /dev/null +++ b/lib/crewai/tests/llms/cerebras/test_cerebras_completion.py @@ -0,0 +1,284 @@ +"""Tests for the native Cerebras provider. + +Two flavors: +- Unit tests under the ``Test*`` classes — exercise factory routing, field + normalization, env-var resolution, client construction, and the fallback + to OpenAI-compatible when the SDK extra is not installed. These run with + ``CEREBRAS_API_KEY`` / ``CEREBRAS_BASE_URL`` cleared so each test states + the env it depends on. +- Module-level VCR tests — replay real Cerebras API responses from cassettes. + To re-record, set ``CEREBRAS_API_KEY`` and run with + ``PYTEST_VCR_RECORD_MODE=new_episodes`` (or delete the target cassette + and use the default ``once`` mode). 
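+
+Example re-record invocation (a sketch; the exact test-runner wrapper may
+differ in your environment)::
+
+    CEREBRAS_API_KEY=... PYTEST_VCR_RECORD_MODE=new_episodes \
+        uv run pytest lib/crewai/tests/llms/cerebras/test_cerebras_completion.py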
+""" + +from __future__ import annotations + +import builtins +import sys +from unittest.mock import patch + +import pytest + +from crewai.llm import LLM +from crewai.llms.providers.cerebras.completion import CerebrasCompletion +from crewai.llms.providers.openai_compatible.completion import ( + OpenAICompatibleCompletion, +) + + +@pytest.fixture +def clear_cerebras_env(monkeypatch): + monkeypatch.delenv("CEREBRAS_API_KEY", raising=False) + monkeypatch.delenv("CEREBRAS_BASE_URL", raising=False) + + +@pytest.mark.usefixtures("clear_cerebras_env") +class TestCerebrasFactoryRouting: + def test_provider_prefix_routes_to_native(self): + llm = LLM(model="cerebras/gpt-oss-120b", api_key="sk-test") + assert isinstance(llm, CerebrasCompletion) + assert llm.llm_type == "cerebras" + assert llm.provider == "cerebras" + assert llm.is_litellm is False + + def test_explicit_provider_kwarg_routes_to_native(self): + llm = LLM(model="gpt-oss-120b", provider="cerebras", api_key="sk-test") + assert isinstance(llm, CerebrasCompletion) + + def test_falls_back_to_openai_compat_when_sdk_missing(self, monkeypatch): + # Drop any cached imports so the factory re-imports the cerebras module. + for mod_name in list(sys.modules): + if mod_name.startswith( + "crewai.llms.providers.cerebras" + ) or mod_name.startswith("cerebras"): + monkeypatch.delitem(sys.modules, mod_name, raising=False) + + real_import = builtins.__import__ + + def _import_blocker(name, *args, **kwargs): + if name.startswith("cerebras"): + raise ImportError(f"simulated missing dep: {name}") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", _import_blocker) + + llm = LLM(model="cerebras/gpt-oss-120b", api_key="sk-test") + # Without the SDK, factory falls through to the OpenAI-compatible path. + assert isinstance(llm, OpenAICompatibleCompletion) + + +@pytest.mark.usefixtures("clear_cerebras_env") +class TestCerebrasFieldNormalization: + def test_default_base_url_left_unset(self): + # We deliberately don't pin a default base URL — the SDK has its own. + llm = CerebrasCompletion(model="gpt-oss-120b", api_key="sk-test") + assert llm.base_url is None + assert llm.api == "completions" + assert llm.provider == "cerebras" + + def test_env_var_base_url_override(self, monkeypatch): + monkeypatch.setenv("CEREBRAS_BASE_URL", "https://custom.cerebras.example/v1") + llm = CerebrasCompletion(model="gpt-oss-120b", api_key="sk-test") + assert llm.base_url == "https://custom.cerebras.example/v1" + + def test_explicit_base_url_takes_precedence(self, monkeypatch): + monkeypatch.setenv("CEREBRAS_BASE_URL", "https://from-env.example/v1") + llm = CerebrasCompletion( + model="gpt-oss-120b", + api_key="sk-test", + base_url="https://explicit.example/v1", + ) + assert llm.base_url == "https://explicit.example/v1" + + def test_api_forced_to_completions(self): + # Even if a caller tries to set api="responses", the validator clamps it. 
+ llm = CerebrasCompletion( + model="gpt-oss-120b", api_key="sk-test", api="responses" + ) + assert llm.api == "completions" + + +@pytest.mark.usefixtures("clear_cerebras_env") +class TestCerebrasApiKeyResolution: + def test_env_var_picked_up(self, monkeypatch): + monkeypatch.setenv("CEREBRAS_API_KEY", "env-key") + llm = CerebrasCompletion(model="gpt-oss-120b") + assert llm.api_key == "env-key" + + def test_explicit_api_key_takes_precedence(self, monkeypatch): + monkeypatch.setenv("CEREBRAS_API_KEY", "env-key") + llm = CerebrasCompletion(model="gpt-oss-120b", api_key="explicit-key") + assert llm.api_key == "explicit-key" + + def test_construction_succeeds_without_key(self): + # Lazy client init: missing key should not crash construction. + llm = CerebrasCompletion(model="gpt-oss-120b") + assert llm.api_key is None + + def test_get_client_params_raises_without_key(self): + llm = CerebrasCompletion(model="gpt-oss-120b") + with pytest.raises(ValueError, match="CEREBRAS_API_KEY"): + llm._get_client_params() + + +@pytest.mark.usefixtures("clear_cerebras_env") +class TestCerebrasClientBuild: + def test_sync_client_uses_cerebras_sdk(self, monkeypatch): + monkeypatch.setenv("CEREBRAS_API_KEY", "sk-test") + llm = CerebrasCompletion(model="gpt-oss-120b") + client = llm._get_sync_client() + from cerebras.cloud.sdk import Cerebras + + assert isinstance(client, Cerebras) + + def test_async_client_uses_cerebras_sdk(self, monkeypatch): + monkeypatch.setenv("CEREBRAS_API_KEY", "sk-test") + llm = CerebrasCompletion(model="gpt-oss-120b") + client = llm._get_async_client() + from cerebras.cloud.sdk import AsyncCerebras + + assert isinstance(client, AsyncCerebras) + + def test_client_params_threaded_through(self, monkeypatch): + monkeypatch.setenv("CEREBRAS_API_KEY", "sk-test") + llm = CerebrasCompletion( + model="gpt-oss-120b", + timeout=30.0, + max_retries=5, + default_headers={"X-Custom": "yes"}, + ) + params = llm._get_client_params() + assert params["timeout"] == 30.0 + assert params["max_retries"] == 5 + assert params["default_headers"] == {"X-Custom": "yes"} + assert params["api_key"] == "sk-test" + # base_url omitted so the SDK uses its own default. 
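+        # We deliberately avoid asserting what that default is (believed to
+        # be https://api.cerebras.ai) so the test stays decoupled from SDK
+        # internals.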
+ assert "base_url" not in params + + def test_client_params_includes_base_url_when_set(self, monkeypatch): + monkeypatch.setenv("CEREBRAS_API_KEY", "sk-test") + llm = CerebrasCompletion( + model="gpt-oss-120b", base_url="https://override.example/api" + ) + params = llm._get_client_params() + assert params["base_url"] == "https://override.example/api" + + +@pytest.mark.usefixtures("clear_cerebras_env") +class TestCerebrasConfigDict: + def test_specific_fields_included_when_set(self): + llm = CerebrasCompletion( + model="gpt-oss-120b", + api_key="sk-test", + service_tier="priority", + prompt_cache_key="cache-1", + clear_thinking=True, + reasoning_effort="high", + ) + config = llm.to_config_dict() + assert config["service_tier"] == "priority" + assert config["prompt_cache_key"] == "cache-1" + assert config["clear_thinking"] is True + assert config["reasoning_effort"] == "high" + + def test_specific_fields_omitted_when_unset(self): + llm = CerebrasCompletion(model="gpt-oss-120b", api_key="sk-test") + config = llm.to_config_dict() + assert "service_tier" not in config + assert "prompt_cache_key" not in config + assert "clear_thinking" not in config + + +@pytest.mark.usefixtures("clear_cerebras_env") +class TestCerebrasCompletionParams: + """Verify Cerebras-specific kwargs reach the chat.completions.create call.""" + + def test_reasoning_effort_threaded_for_non_o1_models(self): + llm = CerebrasCompletion( + model="gpt-oss-120b", api_key="sk-test", reasoning_effort="high" + ) + params = llm._prepare_completion_params(messages=[]) + # Parent gates this on is_o1_model — Cerebras must thread it regardless. + assert params["reasoning_effort"] == "high" + + def test_service_tier_threaded(self): + llm = CerebrasCompletion( + model="llama3.1-8b", api_key="sk-test", service_tier="priority" + ) + params = llm._prepare_completion_params(messages=[]) + assert params["service_tier"] == "priority" + + def test_prompt_cache_key_threaded(self): + llm = CerebrasCompletion( + model="llama3.1-8b", api_key="sk-test", prompt_cache_key="run-42" + ) + params = llm._prepare_completion_params(messages=[]) + assert params["prompt_cache_key"] == "run-42" + + def test_clear_thinking_threaded(self): + llm = CerebrasCompletion( + model="zai-glm-4.7", api_key="sk-test", clear_thinking=True + ) + params = llm._prepare_completion_params(messages=[]) + assert params["clear_thinking"] is True + + def test_cerebras_specific_fields_omitted_when_unset(self): + llm = CerebrasCompletion(model="llama3.1-8b", api_key="sk-test") + params = llm._prepare_completion_params(messages=[]) + assert "service_tier" not in params + assert "prompt_cache_key" not in params + assert "clear_thinking" not in params + assert "reasoning_effort" not in params + + +@pytest.mark.vcr(filter_headers=["authorization", "x-api-key"]) +def test_cerebras_basic_completion(): + """End-to-end completion against Cerebras (replays from cassette).""" + llm = LLM(model="cerebras/llama3.1-8b", max_completion_tokens=32) + assert isinstance(llm, CerebrasCompletion) + + result = llm.call("Reply with exactly the word: OK") + + assert isinstance(result, str) + assert len(result) > 0 + + +@pytest.mark.vcr(filter_headers=["authorization", "x-api-key"]) +def test_cerebras_streaming_completion(): + """Streaming completion against Cerebras (replays from cassette).""" + llm = LLM(model="cerebras/llama3.1-8b", stream=True, max_completion_tokens=32) + assert isinstance(llm, CerebrasCompletion) + + result = llm.call("Count: one, two, three.") + + assert isinstance(result, str) + 
assert len(result) > 0 + + +@pytest.mark.vcr(filter_headers=["authorization", "x-api-key"]) +def test_cerebras_temperature_and_seed_passed_to_sdk(): + """Deterministic-sampling params reach the Cerebras SDK call.""" + llm = LLM( + model="cerebras/llama3.1-8b", + temperature=0.0, + seed=7, + max_completion_tokens=64, + ) + assert isinstance(llm, CerebrasCompletion) + + original_create = llm._client.chat.completions.create + captured: dict = {} + + def capture_and_call(**kwargs): + captured.update(kwargs) + return original_create(**kwargs) + + with patch.object( + llm._client.chat.completions, "create", side_effect=capture_and_call + ): + llm.call("Say hi.") + + assert captured["model"] == "llama3.1-8b" + assert captured["temperature"] == 0.0 + assert captured["seed"] == 7 diff --git a/uv.lock b/uv.lock index 0c91bdd1f..bd2ad6974 100644 --- a/uv.lock +++ b/uv.lock @@ -13,7 +13,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-27T16:00:00Z" +exclude-newer = "2026-04-28T07:00:00Z" [manifest] members = [ @@ -789,6 +789,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/06/f3/39cf3367b8107baa44f861dc802cbf16263c945b62d8265d36034fc07bea/cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114", size = 13918, upload-time = "2026-03-09T20:51:27.33Z" }, ] +[[package]] +name = "cerebras-cloud-sdk" +version = "1.67.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "httpx", extra = ["http2"] }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/92/12/c201f07582068141e88f9a523ab02fdc97de58f2f7c0df775c6c52b9d8dd/cerebras_cloud_sdk-1.67.0.tar.gz", hash = "sha256:3aed6f86c6c7a83ee9d4cfb08a2acea089cebf2af5b8aed116ef79995a4f4813", size = 131536, upload-time = "2026-01-29T23:31:27.306Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/5e/36a364f3d1bab4073454b75e7c91dc7ec6879b960063d1a9c929f1c7ea71/cerebras_cloud_sdk-1.67.0-py3-none-any.whl", hash = "sha256:658b79ca2e9c16f75cc6b4e5d523ee014c9e54a88bd39f88905c28ecb33daae1", size = 97807, upload-time = "2026-01-29T23:31:25.77Z" }, +] + [[package]] name = "certifi" version = "2026.2.25" @@ -1330,6 +1347,9 @@ azure-ai-inference = [ bedrock = [ { name = "boto3" }, ] +cerebras = [ + { name = "cerebras-cloud-sdk" }, +] docling = [ { name = "docling" }, ] @@ -1383,6 +1403,7 @@ requires-dist = [ { name = "azure-identity", marker = "extra == 'azure-ai-inference'", specifier = ">=1.17.0,<2" }, { name = "boto3", marker = "extra == 'aws'", specifier = "~=1.42.79" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = "~=1.42.79" }, + { name = "cerebras-cloud-sdk", marker = "extra == 'cerebras'", specifier = "~=1.67.0" }, { name = "chromadb", specifier = "~=1.1.0" }, { name = "click", specifier = "~=8.1.7" }, { name = "crewai-cli", editable = "lib/cli" }, @@ -1426,7 +1447,7 @@ requires-dist = [ { name = "tomli-w", specifier = "~=1.1.0" }, { name = "voyageai", marker = "extra == 'voyageai'", specifier = "~=0.3.5" }, ] -provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "qdrant-edge", "tools", "voyageai", "watson"] +provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "cerebras", "docling", "embeddings", "file-processing", 
"google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "qdrant-edge", "tools", "voyageai", "watson"] [[package]] name = "crewai-cli"