mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
fix: retrieve function_calling_llm from registered LLMs in CrewBase
This commit is contained in:
@@ -137,13 +137,11 @@ def CrewBase(cls: T) -> T:
|
||||
all_functions, "is_cache_handler"
|
||||
)
|
||||
callbacks = self._filter_functions(all_functions, "is_callback")
|
||||
agents = self._filter_functions(all_functions, "is_agent")
|
||||
|
||||
for agent_name, agent_info in self.agents_config.items():
|
||||
self._map_agent_variables(
|
||||
agent_name,
|
||||
agent_info,
|
||||
agents,
|
||||
llms,
|
||||
tool_functions,
|
||||
cache_handler_functions,
|
||||
@@ -154,7 +152,6 @@ def CrewBase(cls: T) -> T:
|
||||
self,
|
||||
agent_name: str,
|
||||
agent_info: Dict[str, Any],
|
||||
agents: Dict[str, Callable],
|
||||
llms: Dict[str, Callable],
|
||||
tool_functions: Dict[str, Callable],
|
||||
cache_handler_functions: Dict[str, Callable],
|
||||
@@ -172,9 +169,10 @@ def CrewBase(cls: T) -> T:
|
||||
]
|
||||
|
||||
if function_calling_llm := agent_info.get("function_calling_llm"):
|
||||
self.agents_config[agent_name]["function_calling_llm"] = agents[
|
||||
function_calling_llm
|
||||
]()
|
||||
try:
|
||||
self.agents_config[agent_name]["function_calling_llm"] = llms[function_calling_llm]()
|
||||
except KeyError:
|
||||
self.agents_config[agent_name]["function_calling_llm"] = function_calling_llm
|
||||
|
||||
if step_callback := agent_info.get("step_callback"):
|
||||
self.agents_config[agent_name]["step_callback"] = callbacks[
|
||||
|
||||
@@ -8,6 +8,7 @@ researcher:
|
||||
developments in {topic}. Known for your ability to find the most relevant
|
||||
information and present it in a clear and concise manner.
|
||||
verbose: true
|
||||
function_calling_llm: "local_llm"
|
||||
|
||||
reporting_analyst:
|
||||
role: >
|
||||
@@ -18,4 +19,5 @@ reporting_analyst:
|
||||
You're a meticulous analyst with a keen eye for detail. You're known for
|
||||
your ability to turn complex data into clear and concise reports, making
|
||||
it easy for others to understand and act on the information you provide.
|
||||
verbose: true
|
||||
verbose: true
|
||||
function_calling_llm: "online_llm"
|
||||
@@ -31,6 +31,13 @@ class InternalCrew:
|
||||
agents_config = "config/agents.yaml"
|
||||
tasks_config = "config/tasks.yaml"
|
||||
|
||||
@llm
|
||||
def local_llm(self):
|
||||
return LLM(
|
||||
model='openai/model_name',
|
||||
api_key="None",
|
||||
base_url="http://xxx.xxx.xxx.xxx:8000/v1")
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(config=self.agents_config["researcher"])
|
||||
@@ -105,6 +112,20 @@ def test_task_name():
|
||||
), "Custom task name is not being set as expected"
|
||||
|
||||
|
||||
def test_agent_function_calling_llm():
|
||||
crew = InternalCrew()
|
||||
llm = crew.local_llm()
|
||||
obj_llm_agent = crew.researcher()
|
||||
assert (
|
||||
obj_llm_agent.function_calling_llm is llm
|
||||
), "agent's function_calling_llm is incorrect"
|
||||
|
||||
str_llm_agent = crew.reporting_analyst()
|
||||
assert (
|
||||
str_llm_agent.function_calling_llm.model == "online_llm"
|
||||
), "agent's function_calling_llm is incorrect"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_before_kickoff_modification():
|
||||
crew = InternalCrew()
|
||||
|
||||
Reference in New Issue
Block a user