diff --git a/src/crewai/lite_agent.py b/src/crewai/lite_agent.py index 58d60e426..fff13cd08 100644 --- a/src/crewai/lite_agent.py +++ b/src/crewai/lite_agent.py @@ -41,6 +41,7 @@ from crewai.agents.parser import ( ) from crewai.flow.flow_trackable import FlowTrackable from crewai.llm import LLM +from crewai.llms.base_llm import BaseLLM from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool from crewai.utilities import I18N @@ -209,7 +210,7 @@ class LiteAgent(FlowTrackable, BaseModel): def setup_llm(self): """Set up the LLM and other components after initialization.""" self.llm = create_llm(self.llm) - if not isinstance(self.llm, LLM): + if not isinstance(self.llm, BaseLLM): raise ValueError("Unable to create LLM instance") # Initialize callbacks @@ -232,7 +233,7 @@ class LiteAgent(FlowTrackable, BaseModel): elif isinstance(self.guardrail, str): from crewai.tasks.llm_guardrail import LLMGuardrail - assert isinstance(self.llm, LLM) + assert isinstance(self.llm, BaseLLM) self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm) diff --git a/tests/test_lite_agent.py b/tests/test_lite_agent.py index 6b0ca4b70..fb5d0d8be 100644 --- a/tests/test_lite_agent.py +++ b/tests/test_lite_agent.py @@ -418,3 +418,55 @@ def test_agent_output_when_guardrail_returns_base_model(): result = agent.kickoff(messages="Top 10 best players in the world?") assert result.pydantic == Player(name="Lionel Messi", country="Argentina") + + +@pytest.mark.vcr(filter_headers=["authorization"]) +def test_lite_agent_with_custom_llm_and_guardrails(): + """Test that CustomLLM (inheriting from BaseLLM) works with guardrails.""" + from crewai.llms.base_llm import BaseLLM + + class CustomLLM(BaseLLM): + def __init__(self, response="Custom response"): + super().__init__(model="custom-model") + self.response = response + self.call_count = 0 + + def call(self, messages, tools=None, callbacks=None, available_functions=None, from_task=None, from_agent=None): + self.call_count += 1 + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + for message in messages: + if isinstance(message["content"], str): + message["content"] = [{"type": "text", "text": message["content"]}] + return self.response + + custom_llm = CustomLLM(response="Brazilian soccer players are the best!") + + agent = Agent( + role="Sports Analyst", + goal="Analyze soccer players", + backstory="You analyze soccer players and their performance.", + llm=custom_llm, + guardrail="Only include Brazilian players" + ) + + result = agent.kickoff("Tell me about the best soccer players") + + assert custom_llm.call_count > 0 + assert "Brazilian" in result.raw + + custom_llm2 = CustomLLM(response="Original response") + + def test_guardrail(output): + return (True, "Modified by guardrail") + + agent2 = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + llm=custom_llm2, + guardrail=test_guardrail + ) + + result2 = agent2.kickoff("Test message") + assert result2.raw == "Modified by guardrail"