mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
feat: enhance InternalInstructor to support multiple LLM providers (#3767)
* feat: enhance InternalInstructor to support multiple LLM providers - Updated InternalInstructor to conditionally create an instructor client based on the LLM provider. - Introduced a new method _create_instructor_client to handle client creation using the modern from_provider pattern. - Added functionality to extract the provider from the LLM model name. - Implemented tests for InternalInstructor with various LLM providers including OpenAI, Anthropic, Gemini, and Azure, ensuring robust integration and error handling. This update improves flexibility and extensibility for different LLM integrations. * fix test
This commit is contained in:
@@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Any, Generic, TypeGuard, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.utilities.logger_utils import suppress_warnings
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
@@ -11,9 +13,6 @@ if TYPE_CHECKING:
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
from crewai.utilities.logger_utils import suppress_warnings
|
||||
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
@@ -62,9 +61,59 @@ class InternalInstructor(Generic[T]):
|
||||
|
||||
with suppress_warnings():
|
||||
import instructor # type: ignore[import-untyped]
|
||||
from litellm import completion
|
||||
|
||||
self._client = instructor.from_litellm(completion)
|
||||
if (
|
||||
self.llm is not None
|
||||
and hasattr(self.llm, "is_litellm")
|
||||
and self.llm.is_litellm
|
||||
):
|
||||
from litellm import completion
|
||||
|
||||
self._client = instructor.from_litellm(completion)
|
||||
else:
|
||||
self._client = self._create_instructor_client()
|
||||
|
||||
def _create_instructor_client(self) -> Any:
|
||||
"""Create instructor client using the modern from_provider pattern.
|
||||
|
||||
Returns:
|
||||
Instructor client configured for the LLM provider
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not supported
|
||||
"""
|
||||
import instructor
|
||||
|
||||
if isinstance(self.llm, str):
|
||||
model_string = self.llm
|
||||
elif self.llm is not None and hasattr(self.llm, "model"):
|
||||
model_string = self.llm.model
|
||||
else:
|
||||
raise ValueError("LLM must be a string or have a model attribute")
|
||||
|
||||
if isinstance(self.llm, str):
|
||||
provider = self._extract_provider()
|
||||
elif self.llm is not None and hasattr(self.llm, "provider"):
|
||||
provider = self.llm.provider
|
||||
else:
|
||||
provider = "openai" # Default fallback
|
||||
|
||||
return instructor.from_provider(f"{provider}/{model_string}")
|
||||
|
||||
def _extract_provider(self) -> str:
|
||||
"""Extract provider from LLM model name.
|
||||
|
||||
Returns:
|
||||
Provider name (e.g., 'openai', 'anthropic', etc.)
|
||||
"""
|
||||
if self.llm is not None and hasattr(self.llm, "provider") and self.llm.provider:
|
||||
return self.llm.provider
|
||||
|
||||
if isinstance(self.llm, str):
|
||||
return self.llm.partition("/")[0] or "openai"
|
||||
if self.llm is not None and hasattr(self.llm, "model"):
|
||||
return self.llm.model.partition("/")[0] or "openai"
|
||||
return "openai"
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Convert the structured output to JSON format.
|
||||
@@ -96,6 +145,6 @@ class InternalInstructor(Generic[T]):
|
||||
else:
|
||||
model_name = self.llm.model
|
||||
|
||||
return self._client.chat.completions.create(
|
||||
return self._client.chat.completions.create( # type: ignore[no-any-return]
|
||||
model=model_name, response_model=self.model, messages=messages
|
||||
)
|
||||
|
||||
@@ -902,7 +902,8 @@ def test_agent_step_callback():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_function_calling_llm():
|
||||
llm = "gpt-4o"
|
||||
from crewai.llm import LLM
|
||||
llm = LLM(model="gpt-4o", is_litellm=True)
|
||||
|
||||
@tool
|
||||
def learn_about_ai() -> str:
|
||||
|
||||
@@ -22,7 +22,7 @@ import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_config(request) -> dict:
|
||||
def vcr_config(request: pytest.FixtureRequest) -> dict[str, str]:
|
||||
return {
|
||||
"cassette_library_dir": os.path.join(os.path.dirname(__file__), "cassettes"),
|
||||
}
|
||||
@@ -65,7 +65,7 @@ class CustomConverter(Converter):
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
def mock_agent() -> Mock:
|
||||
agent = Mock()
|
||||
agent.function_calling_llm = None
|
||||
agent.llm = Mock()
|
||||
@@ -73,7 +73,7 @@ def mock_agent():
|
||||
|
||||
|
||||
# Tests for convert_to_model
|
||||
def test_convert_to_model_with_valid_json():
|
||||
def test_convert_to_model_with_valid_json() -> None:
|
||||
result = '{"name": "John", "age": 30}'
|
||||
output = convert_to_model(result, SimpleModel, None, None)
|
||||
assert isinstance(output, SimpleModel)
|
||||
@@ -81,7 +81,7 @@ def test_convert_to_model_with_valid_json():
|
||||
assert output.age == 30
|
||||
|
||||
|
||||
def test_convert_to_model_with_invalid_json():
|
||||
def test_convert_to_model_with_invalid_json() -> None:
|
||||
result = '{"name": "John", "age": "thirty"}'
|
||||
with patch("crewai.utilities.converter.handle_partial_json") as mock_handle:
|
||||
mock_handle.return_value = "Fallback result"
|
||||
@@ -89,13 +89,13 @@ def test_convert_to_model_with_invalid_json():
|
||||
assert output == "Fallback result"
|
||||
|
||||
|
||||
def test_convert_to_model_with_no_model():
|
||||
def test_convert_to_model_with_no_model() -> None:
|
||||
result = "Plain text"
|
||||
output = convert_to_model(result, None, None, None)
|
||||
assert output == "Plain text"
|
||||
|
||||
|
||||
def test_convert_to_model_with_special_characters():
|
||||
def test_convert_to_model_with_special_characters() -> None:
|
||||
json_string_test = """
|
||||
{
|
||||
"responses": [
|
||||
@@ -114,7 +114,7 @@ def test_convert_to_model_with_special_characters():
|
||||
)
|
||||
|
||||
|
||||
def test_convert_to_model_with_escaped_special_characters():
|
||||
def test_convert_to_model_with_escaped_special_characters() -> None:
|
||||
json_string_test = json.dumps(
|
||||
{
|
||||
"responses": [
|
||||
@@ -133,7 +133,7 @@ def test_convert_to_model_with_escaped_special_characters():
|
||||
)
|
||||
|
||||
|
||||
def test_convert_to_model_with_multiple_special_characters():
|
||||
def test_convert_to_model_with_multiple_special_characters() -> None:
|
||||
json_string_test = """
|
||||
{
|
||||
"responses": [
|
||||
@@ -153,7 +153,7 @@ def test_convert_to_model_with_multiple_special_characters():
|
||||
|
||||
|
||||
# Tests for validate_model
|
||||
def test_validate_model_pydantic_output():
|
||||
def test_validate_model_pydantic_output() -> None:
|
||||
result = '{"name": "Alice", "age": 25}'
|
||||
output = validate_model(result, SimpleModel, False)
|
||||
assert isinstance(output, SimpleModel)
|
||||
@@ -161,7 +161,7 @@ def test_validate_model_pydantic_output():
|
||||
assert output.age == 25
|
||||
|
||||
|
||||
def test_validate_model_json_output():
|
||||
def test_validate_model_json_output() -> None:
|
||||
result = '{"name": "Bob", "age": 40}'
|
||||
output = validate_model(result, SimpleModel, True)
|
||||
assert isinstance(output, dict)
|
||||
@@ -169,7 +169,7 @@ def test_validate_model_json_output():
|
||||
|
||||
|
||||
# Tests for handle_partial_json
|
||||
def test_handle_partial_json_with_valid_partial():
|
||||
def test_handle_partial_json_with_valid_partial() -> None:
|
||||
result = 'Some text {"name": "Charlie", "age": 35} more text'
|
||||
output = handle_partial_json(result, SimpleModel, False, None)
|
||||
assert isinstance(output, SimpleModel)
|
||||
@@ -177,7 +177,7 @@ def test_handle_partial_json_with_valid_partial():
|
||||
assert output.age == 35
|
||||
|
||||
|
||||
def test_handle_partial_json_with_invalid_partial(mock_agent):
|
||||
def test_handle_partial_json_with_invalid_partial(mock_agent: Mock) -> None:
|
||||
result = "No valid JSON here"
|
||||
with patch("crewai.utilities.converter.convert_with_instructions") as mock_convert:
|
||||
mock_convert.return_value = "Converted result"
|
||||
@@ -189,8 +189,8 @@ def test_handle_partial_json_with_invalid_partial(mock_agent):
|
||||
@patch("crewai.utilities.converter.create_converter")
|
||||
@patch("crewai.utilities.converter.get_conversion_instructions")
|
||||
def test_convert_with_instructions_success(
|
||||
mock_get_instructions, mock_create_converter, mock_agent
|
||||
):
|
||||
mock_get_instructions: Mock, mock_create_converter: Mock, mock_agent: Mock
|
||||
) -> None:
|
||||
mock_get_instructions.return_value = "Instructions"
|
||||
mock_converter = Mock()
|
||||
mock_converter.to_pydantic.return_value = SimpleModel(name="David", age=50)
|
||||
@@ -207,8 +207,8 @@ def test_convert_with_instructions_success(
|
||||
@patch("crewai.utilities.converter.create_converter")
|
||||
@patch("crewai.utilities.converter.get_conversion_instructions")
|
||||
def test_convert_with_instructions_failure(
|
||||
mock_get_instructions, mock_create_converter, mock_agent
|
||||
):
|
||||
mock_get_instructions: Mock, mock_create_converter: Mock, mock_agent: Mock
|
||||
) -> None:
|
||||
mock_get_instructions.return_value = "Instructions"
|
||||
mock_converter = Mock()
|
||||
mock_converter.to_pydantic.return_value = ConverterError("Conversion failed")
|
||||
@@ -222,7 +222,7 @@ def test_convert_with_instructions_failure(
|
||||
|
||||
|
||||
# Tests for get_conversion_instructions
|
||||
def test_get_conversion_instructions_gpt():
|
||||
def test_get_conversion_instructions_gpt() -> None:
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
with patch.object(LLM, "supports_function_calling") as supports_function_calling:
|
||||
supports_function_calling.return_value = True
|
||||
@@ -237,7 +237,7 @@ def test_get_conversion_instructions_gpt():
|
||||
assert instructions == expected_instructions
|
||||
|
||||
|
||||
def test_get_conversion_instructions_non_gpt():
|
||||
def test_get_conversion_instructions_non_gpt() -> None:
|
||||
llm = LLM(model="ollama/llama3.1", base_url="http://localhost:11434")
|
||||
with patch.object(LLM, "supports_function_calling", return_value=False):
|
||||
instructions = get_conversion_instructions(SimpleModel, llm)
|
||||
@@ -246,17 +246,17 @@ def test_get_conversion_instructions_non_gpt():
|
||||
|
||||
|
||||
# Tests for is_gpt
|
||||
def test_supports_function_calling_true():
|
||||
def test_supports_function_calling_true() -> None:
|
||||
llm = LLM(model="gpt-4o")
|
||||
assert llm.supports_function_calling() is True
|
||||
|
||||
|
||||
def test_supports_function_calling_false():
|
||||
def test_supports_function_calling_false() -> None:
|
||||
llm = LLM(model="non-existent-model", is_litellm=True)
|
||||
assert llm.supports_function_calling() is False
|
||||
|
||||
|
||||
def test_create_converter_with_mock_agent():
|
||||
def test_create_converter_with_mock_agent() -> None:
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.get_output_converter.return_value = MagicMock(spec=Converter)
|
||||
|
||||
@@ -272,7 +272,7 @@ def test_create_converter_with_mock_agent():
|
||||
mock_agent.get_output_converter.assert_called_once()
|
||||
|
||||
|
||||
def test_create_converter_with_custom_converter():
|
||||
def test_create_converter_with_custom_converter() -> None:
|
||||
converter = create_converter(
|
||||
converter_cls=CustomConverter,
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
@@ -284,7 +284,7 @@ def test_create_converter_with_custom_converter():
|
||||
assert isinstance(converter, CustomConverter)
|
||||
|
||||
|
||||
def test_create_converter_fails_without_agent_or_converter_cls():
|
||||
def test_create_converter_fails_without_agent_or_converter_cls() -> None:
|
||||
with pytest.raises(
|
||||
ValueError, match="Either agent or converter_cls must be provided"
|
||||
):
|
||||
@@ -293,13 +293,13 @@ def test_create_converter_fails_without_agent_or_converter_cls():
|
||||
)
|
||||
|
||||
|
||||
def test_generate_model_description_simple_model():
|
||||
def test_generate_model_description_simple_model() -> None:
|
||||
description = generate_model_description(SimpleModel)
|
||||
expected_description = '{\n "name": str,\n "age": int\n}'
|
||||
assert description == expected_description
|
||||
|
||||
|
||||
def test_generate_model_description_nested_model():
|
||||
def test_generate_model_description_nested_model() -> None:
|
||||
description = generate_model_description(NestedModel)
|
||||
expected_description = (
|
||||
'{\n "id": int,\n "data": {\n "name": str,\n "age": int\n}\n}'
|
||||
@@ -307,7 +307,7 @@ def test_generate_model_description_nested_model():
|
||||
assert description == expected_description
|
||||
|
||||
|
||||
def test_generate_model_description_optional_field():
|
||||
def test_generate_model_description_optional_field() -> None:
|
||||
class ModelWithOptionalField(BaseModel):
|
||||
name: str
|
||||
age: int | None
|
||||
@@ -317,7 +317,7 @@ def test_generate_model_description_optional_field():
|
||||
assert description == expected_description
|
||||
|
||||
|
||||
def test_generate_model_description_list_field():
|
||||
def test_generate_model_description_list_field() -> None:
|
||||
class ModelWithListField(BaseModel):
|
||||
items: list[int]
|
||||
|
||||
@@ -326,7 +326,7 @@ def test_generate_model_description_list_field():
|
||||
assert description == expected_description
|
||||
|
||||
|
||||
def test_generate_model_description_dict_field():
|
||||
def test_generate_model_description_dict_field() -> None:
|
||||
class ModelWithDictField(BaseModel):
|
||||
attributes: dict[str, int]
|
||||
|
||||
@@ -336,7 +336,7 @@ def test_generate_model_description_dict_field():
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_convert_with_instructions():
|
||||
def test_convert_with_instructions() -> None:
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
sample_text = "Name: Alice, Age: 30"
|
||||
|
||||
@@ -358,7 +358,7 @@ def test_convert_with_instructions():
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_converter_with_llama3_2_model():
|
||||
def test_converter_with_llama3_2_model() -> None:
|
||||
llm = LLM(model="openrouter/meta-llama/llama-3.2-3b-instruct")
|
||||
sample_text = "Name: Alice Llama, Age: 30"
|
||||
instructions = get_conversion_instructions(SimpleModel, llm)
|
||||
@@ -375,7 +375,7 @@ def test_converter_with_llama3_2_model():
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_converter_with_llama3_1_model():
|
||||
def test_converter_with_llama3_1_model() -> None:
|
||||
llm = LLM(model="ollama/llama3.1", base_url="http://localhost:11434")
|
||||
sample_text = "Name: Alice Llama, Age: 30"
|
||||
instructions = get_conversion_instructions(SimpleModel, llm)
|
||||
@@ -392,7 +392,7 @@ def test_converter_with_llama3_1_model():
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_converter_with_nested_model():
|
||||
def test_converter_with_nested_model() -> None:
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
sample_text = "Name: John Doe\nAge: 30\nAddress: 123 Main St, Anytown, 12345"
|
||||
|
||||
@@ -416,7 +416,7 @@ def test_converter_with_nested_model():
|
||||
|
||||
|
||||
# Tests for error handling
|
||||
def test_converter_error_handling():
|
||||
def test_converter_error_handling() -> None:
|
||||
llm = Mock(spec=LLM)
|
||||
llm.supports_function_calling.return_value = False
|
||||
llm.call.return_value = "Invalid JSON"
|
||||
@@ -437,7 +437,7 @@ def test_converter_error_handling():
|
||||
|
||||
|
||||
# Tests for retry logic
|
||||
def test_converter_retry_logic():
|
||||
def test_converter_retry_logic() -> None:
|
||||
llm = Mock(spec=LLM)
|
||||
llm.supports_function_calling.return_value = False
|
||||
llm.call.side_effect = [
|
||||
@@ -465,7 +465,7 @@ def test_converter_retry_logic():
|
||||
|
||||
|
||||
# Tests for optional fields
|
||||
def test_converter_with_optional_fields():
|
||||
def test_converter_with_optional_fields() -> None:
|
||||
class OptionalModel(BaseModel):
|
||||
name: str
|
||||
age: int | None
|
||||
@@ -492,7 +492,7 @@ def test_converter_with_optional_fields():
|
||||
|
||||
|
||||
# Tests for list fields
|
||||
def test_converter_with_list_field():
|
||||
def test_converter_with_list_field() -> None:
|
||||
class ListModel(BaseModel):
|
||||
items: list[int]
|
||||
|
||||
@@ -515,7 +515,7 @@ def test_converter_with_list_field():
|
||||
assert output.items == [1, 2, 3]
|
||||
|
||||
|
||||
def test_converter_with_enum():
|
||||
def test_converter_with_enum() -> None:
|
||||
class Color(Enum):
|
||||
RED = "red"
|
||||
GREEN = "green"
|
||||
@@ -546,7 +546,7 @@ def test_converter_with_enum():
|
||||
|
||||
|
||||
# Tests for ambiguous input
|
||||
def test_converter_with_ambiguous_input():
|
||||
def test_converter_with_ambiguous_input() -> None:
|
||||
llm = Mock(spec=LLM)
|
||||
llm.supports_function_calling.return_value = False
|
||||
llm.call.return_value = '{"name": "Charlie", "age": "Not an age"}'
|
||||
@@ -567,7 +567,7 @@ def test_converter_with_ambiguous_input():
|
||||
|
||||
|
||||
# Tests for function calling support
|
||||
def test_converter_with_function_calling():
|
||||
def test_converter_with_function_calling() -> None:
|
||||
llm = Mock(spec=LLM)
|
||||
llm.supports_function_calling.return_value = True
|
||||
|
||||
@@ -580,20 +580,359 @@ def test_converter_with_function_calling():
|
||||
model=SimpleModel,
|
||||
instructions="Convert this text.",
|
||||
)
|
||||
converter._create_instructor = Mock(return_value=instructor)
|
||||
|
||||
with patch.object(converter, '_create_instructor', return_value=instructor):
|
||||
output = converter.to_pydantic()
|
||||
|
||||
output = converter.to_pydantic()
|
||||
|
||||
assert isinstance(output, SimpleModel)
|
||||
assert output.name == "Eve"
|
||||
assert output.age == 35
|
||||
assert isinstance(output, SimpleModel)
|
||||
assert output.name == "Eve"
|
||||
assert output.age == 35
|
||||
instructor.to_pydantic.assert_called_once()
|
||||
|
||||
|
||||
def test_generate_model_description_union_field():
|
||||
def test_generate_model_description_union_field() -> None:
|
||||
class UnionModel(BaseModel):
|
||||
field: int | str | None
|
||||
|
||||
description = generate_model_description(UnionModel)
|
||||
expected_description = '{\n "field": int | str | None\n}'
|
||||
assert description == expected_description
|
||||
|
||||
def test_internal_instructor_with_openai_provider() -> None:
|
||||
"""Test InternalInstructor with OpenAI provider using registry pattern."""
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
|
||||
# Mock LLM with OpenAI provider
|
||||
mock_llm = Mock()
|
||||
mock_llm.is_litellm = False
|
||||
mock_llm.model = "gpt-4o"
|
||||
mock_llm.provider = "openai"
|
||||
|
||||
# Mock instructor client
|
||||
mock_client = Mock()
|
||||
mock_client.chat.completions.create.return_value = SimpleModel(name="Test", age=25)
|
||||
|
||||
# Patch the instructor import at the method level
|
||||
with patch.object(InternalInstructor, '_create_instructor_client') as mock_create_client:
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
instructor = InternalInstructor(
|
||||
content="Test content",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm
|
||||
)
|
||||
|
||||
result = instructor.to_pydantic()
|
||||
|
||||
assert isinstance(result, SimpleModel)
|
||||
assert result.name == "Test"
|
||||
assert result.age == 25
|
||||
# Verify the method was called with the correct LLM
|
||||
mock_create_client.assert_called_once()
|
||||
|
||||
|
||||
def test_internal_instructor_with_anthropic_provider() -> None:
|
||||
"""Test InternalInstructor with Anthropic provider using registry pattern."""
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
|
||||
# Mock LLM with Anthropic provider
|
||||
mock_llm = Mock()
|
||||
mock_llm.is_litellm = False
|
||||
mock_llm.model = "claude-3-5-sonnet-20241022"
|
||||
mock_llm.provider = "anthropic"
|
||||
|
||||
# Mock instructor client
|
||||
mock_client = Mock()
|
||||
mock_client.chat.completions.create.return_value = SimpleModel(name="Bob", age=25)
|
||||
|
||||
# Patch the instructor import at the method level
|
||||
with patch.object(InternalInstructor, '_create_instructor_client') as mock_create_client:
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
instructor = InternalInstructor(
|
||||
content="Name: Bob, Age: 25",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm
|
||||
)
|
||||
|
||||
result = instructor.to_pydantic()
|
||||
|
||||
assert isinstance(result, SimpleModel)
|
||||
assert result.name == "Bob"
|
||||
assert result.age == 25
|
||||
# Verify the method was called with the correct LLM
|
||||
mock_create_client.assert_called_once()
|
||||
|
||||
|
||||
def test_factory_pattern_registry_extensibility() -> None:
|
||||
"""Test that the factory pattern registry works with different providers."""
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
|
||||
# Test with OpenAI provider
|
||||
mock_llm_openai = Mock()
|
||||
mock_llm_openai.is_litellm = False
|
||||
mock_llm_openai.model = "gpt-4o-mini"
|
||||
mock_llm_openai.provider = "openai"
|
||||
|
||||
mock_client_openai = Mock()
|
||||
mock_client_openai.chat.completions.create.return_value = SimpleModel(name="Alice", age=30)
|
||||
|
||||
with patch.object(InternalInstructor, '_create_instructor_client') as mock_create_client:
|
||||
mock_create_client.return_value = mock_client_openai
|
||||
|
||||
instructor_openai = InternalInstructor(
|
||||
content="Name: Alice, Age: 30",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm_openai
|
||||
)
|
||||
|
||||
result_openai = instructor_openai.to_pydantic()
|
||||
|
||||
assert isinstance(result_openai, SimpleModel)
|
||||
assert result_openai.name == "Alice"
|
||||
assert result_openai.age == 30
|
||||
|
||||
# Test with Anthropic provider
|
||||
mock_llm_anthropic = Mock()
|
||||
mock_llm_anthropic.is_litellm = False
|
||||
mock_llm_anthropic.model = "claude-3-5-sonnet-20241022"
|
||||
mock_llm_anthropic.provider = "anthropic"
|
||||
|
||||
mock_client_anthropic = Mock()
|
||||
mock_client_anthropic.chat.completions.create.return_value = SimpleModel(name="Bob", age=25)
|
||||
|
||||
with patch.object(InternalInstructor, '_create_instructor_client') as mock_create_client:
|
||||
mock_create_client.return_value = mock_client_anthropic
|
||||
|
||||
instructor_anthropic = InternalInstructor(
|
||||
content="Name: Bob, Age: 25",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm_anthropic
|
||||
)
|
||||
|
||||
result_anthropic = instructor_anthropic.to_pydantic()
|
||||
|
||||
assert isinstance(result_anthropic, SimpleModel)
|
||||
assert result_anthropic.name == "Bob"
|
||||
assert result_anthropic.age == 25
|
||||
|
||||
# Test with Bedrock provider
|
||||
mock_llm_bedrock = Mock()
|
||||
mock_llm_bedrock.is_litellm = False
|
||||
mock_llm_bedrock.model = "claude-3-5-sonnet-20241022"
|
||||
mock_llm_bedrock.provider = "bedrock"
|
||||
|
||||
mock_client_bedrock = Mock()
|
||||
mock_client_bedrock.chat.completions.create.return_value = SimpleModel(name="Charlie", age=35)
|
||||
|
||||
with patch.object(InternalInstructor, '_create_instructor_client') as mock_create_client:
|
||||
mock_create_client.return_value = mock_client_bedrock
|
||||
|
||||
instructor_bedrock = InternalInstructor(
|
||||
content="Name: Charlie, Age: 35",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm_bedrock
|
||||
)
|
||||
|
||||
result_bedrock = instructor_bedrock.to_pydantic()
|
||||
|
||||
assert isinstance(result_bedrock, SimpleModel)
|
||||
assert result_bedrock.name == "Charlie"
|
||||
assert result_bedrock.age == 35
|
||||
|
||||
# Test with Google provider
|
||||
mock_llm_google = Mock()
|
||||
mock_llm_google.is_litellm = False
|
||||
mock_llm_google.model = "gemini-1.5-flash"
|
||||
mock_llm_google.provider = "google"
|
||||
|
||||
mock_client_google = Mock()
|
||||
mock_client_google.chat.completions.create.return_value = SimpleModel(name="Diana", age=28)
|
||||
|
||||
with patch.object(InternalInstructor, '_create_instructor_client') as mock_create_client:
|
||||
mock_create_client.return_value = mock_client_google
|
||||
|
||||
instructor_google = InternalInstructor(
|
||||
content="Name: Diana, Age: 28",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm_google
|
||||
)
|
||||
|
||||
result_google = instructor_google.to_pydantic()
|
||||
|
||||
assert isinstance(result_google, SimpleModel)
|
||||
assert result_google.name == "Diana"
|
||||
assert result_google.age == 28
|
||||
|
||||
# Test with Azure provider
|
||||
mock_llm_azure = Mock()
|
||||
mock_llm_azure.is_litellm = False
|
||||
mock_llm_azure.model = "gpt-4o"
|
||||
mock_llm_azure.provider = "azure"
|
||||
|
||||
mock_client_azure = Mock()
|
||||
mock_client_azure.chat.completions.create.return_value = SimpleModel(name="Eve", age=32)
|
||||
|
||||
with patch.object(InternalInstructor, '_create_instructor_client') as mock_create_client:
|
||||
mock_create_client.return_value = mock_client_azure
|
||||
|
||||
instructor_azure = InternalInstructor(
|
||||
content="Name: Eve, Age: 32",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm_azure
|
||||
)
|
||||
|
||||
result_azure = instructor_azure.to_pydantic()
|
||||
|
||||
assert isinstance(result_azure, SimpleModel)
|
||||
assert result_azure.name == "Eve"
|
||||
assert result_azure.age == 32
|
||||
|
||||
|
||||
def test_internal_instructor_with_bedrock_provider() -> None:
|
||||
"""Test InternalInstructor with AWS Bedrock provider using registry pattern."""
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
|
||||
# Mock LLM with Bedrock provider
|
||||
mock_llm = Mock()
|
||||
mock_llm.is_litellm = False
|
||||
mock_llm.model = "claude-3-5-sonnet-20241022"
|
||||
mock_llm.provider = "bedrock"
|
||||
|
||||
# Mock instructor client
|
||||
mock_client = Mock()
|
||||
mock_client.chat.completions.create.return_value = SimpleModel(name="Charlie", age=35)
|
||||
|
||||
# Patch the instructor import at the method level
|
||||
with patch.object(InternalInstructor, '_create_instructor_client') as mock_create_client:
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
instructor = InternalInstructor(
|
||||
content="Name: Charlie, Age: 35",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm
|
||||
)
|
||||
|
||||
result = instructor.to_pydantic()
|
||||
|
||||
assert isinstance(result, SimpleModel)
|
||||
assert result.name == "Charlie"
|
||||
assert result.age == 35
|
||||
# Verify the method was called with the correct LLM
|
||||
mock_create_client.assert_called_once()
|
||||
|
||||
|
||||
def test_internal_instructor_with_gemini_provider() -> None:
|
||||
"""Test InternalInstructor with Google Gemini provider using registry pattern."""
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
|
||||
# Mock LLM with Gemini provider
|
||||
mock_llm = Mock()
|
||||
mock_llm.is_litellm = False
|
||||
mock_llm.model = "gemini-1.5-flash"
|
||||
mock_llm.provider = "google"
|
||||
|
||||
# Mock instructor client
|
||||
mock_client = Mock()
|
||||
mock_client.chat.completions.create.return_value = SimpleModel(name="Diana", age=28)
|
||||
|
||||
# Patch the instructor import at the method level
|
||||
with patch.object(InternalInstructor, '_create_instructor_client') as mock_create_client:
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
instructor = InternalInstructor(
|
||||
content="Name: Diana, Age: 28",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm
|
||||
)
|
||||
|
||||
result = instructor.to_pydantic()
|
||||
|
||||
assert isinstance(result, SimpleModel)
|
||||
assert result.name == "Diana"
|
||||
assert result.age == 28
|
||||
# Verify the method was called with the correct LLM
|
||||
mock_create_client.assert_called_once()
|
||||
|
||||
|
||||
def test_internal_instructor_with_azure_provider() -> None:
|
||||
"""Test InternalInstructor with Azure OpenAI provider using registry pattern."""
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
|
||||
# Mock LLM with Azure provider
|
||||
mock_llm = Mock()
|
||||
mock_llm.is_litellm = False
|
||||
mock_llm.model = "gpt-4o"
|
||||
mock_llm.provider = "azure"
|
||||
|
||||
# Mock instructor client
|
||||
mock_client = Mock()
|
||||
mock_client.chat.completions.create.return_value = SimpleModel(name="Eve", age=32)
|
||||
|
||||
# Patch the instructor import at the method level
|
||||
with patch.object(InternalInstructor, '_create_instructor_client') as mock_create_client:
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
instructor = InternalInstructor(
|
||||
content="Name: Eve, Age: 32",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm
|
||||
)
|
||||
|
||||
result = instructor.to_pydantic()
|
||||
|
||||
assert isinstance(result, SimpleModel)
|
||||
assert result.name == "Eve"
|
||||
assert result.age == 32
|
||||
# Verify the method was called with the correct LLM
|
||||
mock_create_client.assert_called_once()
|
||||
|
||||
|
||||
def test_internal_instructor_unsupported_provider() -> None:
|
||||
"""Test InternalInstructor with unsupported provider raises appropriate error."""
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
|
||||
# Mock LLM with unsupported provider
|
||||
mock_llm = Mock()
|
||||
mock_llm.is_litellm = False
|
||||
mock_llm.model = "unsupported-model"
|
||||
mock_llm.provider = "unsupported"
|
||||
|
||||
# Mock the _create_instructor_client method to raise an error for unsupported providers
|
||||
with patch.object(InternalInstructor, '_create_instructor_client') as mock_create_client:
|
||||
mock_create_client.side_effect = Exception("Unsupported provider: unsupported")
|
||||
|
||||
# This should raise an error when trying to create the instructor client
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
instructor = InternalInstructor(
|
||||
content="Test content",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm
|
||||
)
|
||||
instructor.to_pydantic()
|
||||
|
||||
# Verify it's the expected error
|
||||
assert "Unsupported provider" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_internal_instructor_real_unsupported_provider() -> None:
|
||||
"""Test InternalInstructor with real unsupported provider using actual instructor library."""
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
|
||||
# Mock LLM with unsupported provider that would actually fail with instructor
|
||||
mock_llm = Mock()
|
||||
mock_llm.is_litellm = False
|
||||
mock_llm.model = "unsupported-model"
|
||||
mock_llm.provider = "unsupported"
|
||||
|
||||
# This should raise a ConfigurationError from the real instructor library
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
instructor = InternalInstructor(
|
||||
content="Test content",
|
||||
model=SimpleModel,
|
||||
llm=mock_llm
|
||||
)
|
||||
instructor.to_pydantic()
|
||||
|
||||
# Verify it's a configuration error about unsupported provider
|
||||
assert "Unsupported provider" in str(exc_info.value) or "unsupported" in str(exc_info.value).lower()
|
||||
|
||||
Reference in New Issue
Block a user