diff --git a/src/crewai/tasks/hallucination_guardrail.py b/src/crewai/tasks/hallucination_guardrail.py index 3079bc243..0e5254c45 100644 --- a/src/crewai/tasks/hallucination_guardrail.py +++ b/src/crewai/tasks/hallucination_guardrail.py @@ -9,6 +9,7 @@ Classes: from typing import Any, Optional, Tuple from crewai.llm import LLM +from crewai.llms.base_llm import BaseLLM from crewai.tasks.task_output import TaskOutput from crewai.utilities.logger import Logger @@ -47,7 +48,7 @@ class HallucinationGuardrail: def __init__( self, context: str, - llm: LLM, + llm: BaseLLM, threshold: Optional[float] = None, tool_response: str = "", ): @@ -60,7 +61,7 @@ class HallucinationGuardrail: tool_response: Optional tool response information that would be used in evaluation. """ self.context = context - self.llm: LLM = llm + self.llm: BaseLLM = llm self.threshold = threshold self.tool_response = tool_response self._logger = Logger(verbose=True) diff --git a/src/crewai/tasks/llm_guardrail.py b/src/crewai/tasks/llm_guardrail.py index 2bb948075..172cdbec5 100644 --- a/src/crewai/tasks/llm_guardrail.py +++ b/src/crewai/tasks/llm_guardrail.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field from crewai.agent import Agent, LiteAgentOutput from crewai.llm import LLM +from crewai.llms.base_llm import BaseLLM from crewai.task import Task from crewai.tasks.task_output import TaskOutput @@ -32,11 +33,11 @@ class LLMGuardrail: def __init__( self, description: str, - llm: LLM, + llm: BaseLLM, ): self.description = description - self.llm: LLM = llm + self.llm: BaseLLM = llm def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput: agent = Agent(