fix: retrieve function_calling_llm from registered LLMs in CrewBase

This commit is contained in:
sakunkun
2025-03-11 11:40:33 +00:00
parent 41a670166a
commit 313038882c
3 changed files with 28 additions and 7 deletions

View File

@@ -137,13 +137,11 @@ def CrewBase(cls: T) -> T:
all_functions, "is_cache_handler" all_functions, "is_cache_handler"
) )
callbacks = self._filter_functions(all_functions, "is_callback") callbacks = self._filter_functions(all_functions, "is_callback")
agents = self._filter_functions(all_functions, "is_agent")
for agent_name, agent_info in self.agents_config.items(): for agent_name, agent_info in self.agents_config.items():
self._map_agent_variables( self._map_agent_variables(
agent_name, agent_name,
agent_info, agent_info,
agents,
llms, llms,
tool_functions, tool_functions,
cache_handler_functions, cache_handler_functions,
@@ -154,7 +152,6 @@ def CrewBase(cls: T) -> T:
self, self,
agent_name: str, agent_name: str,
agent_info: Dict[str, Any], agent_info: Dict[str, Any],
agents: Dict[str, Callable],
llms: Dict[str, Callable], llms: Dict[str, Callable],
tool_functions: Dict[str, Callable], tool_functions: Dict[str, Callable],
cache_handler_functions: Dict[str, Callable], cache_handler_functions: Dict[str, Callable],
@@ -172,9 +169,10 @@ def CrewBase(cls: T) -> T:
] ]
if function_calling_llm := agent_info.get("function_calling_llm"): if function_calling_llm := agent_info.get("function_calling_llm"):
self.agents_config[agent_name]["function_calling_llm"] = agents[ try:
function_calling_llm self.agents_config[agent_name]["function_calling_llm"] = llms[function_calling_llm]()
]() except KeyError:
self.agents_config[agent_name]["function_calling_llm"] = function_calling_llm
if step_callback := agent_info.get("step_callback"): if step_callback := agent_info.get("step_callback"):
self.agents_config[agent_name]["step_callback"] = callbacks[ self.agents_config[agent_name]["step_callback"] = callbacks[

View File

@@ -8,6 +8,7 @@ researcher:
developments in {topic}. Known for your ability to find the most relevant developments in {topic}. Known for your ability to find the most relevant
information and present it in a clear and concise manner. information and present it in a clear and concise manner.
verbose: true verbose: true
function_calling_llm: "local_llm"
reporting_analyst: reporting_analyst:
role: > role: >
@@ -19,3 +20,4 @@ reporting_analyst:
your ability to turn complex data into clear and concise reports, making your ability to turn complex data into clear and concise reports, making
it easy for others to understand and act on the information you provide. it easy for others to understand and act on the information you provide.
verbose: true verbose: true
function_calling_llm: "online_llm"

View File

@@ -31,6 +31,13 @@ class InternalCrew:
agents_config = "config/agents.yaml" agents_config = "config/agents.yaml"
tasks_config = "config/tasks.yaml" tasks_config = "config/tasks.yaml"
@llm
def local_llm(self):
return LLM(
model='openai/model_name',
api_key="None",
base_url="http://xxx.xxx.xxx.xxx:8000/v1")
@agent @agent
def researcher(self): def researcher(self):
return Agent(config=self.agents_config["researcher"]) return Agent(config=self.agents_config["researcher"])
@@ -105,6 +112,20 @@ def test_task_name():
), "Custom task name is not being set as expected" ), "Custom task name is not being set as expected"
def test_agent_function_calling_llm():
crew = InternalCrew()
llm = crew.local_llm()
obj_llm_agent = crew.researcher()
assert (
obj_llm_agent.function_calling_llm is llm
), "agent's function_calling_llm is incorrect"
str_llm_agent = crew.reporting_analyst()
assert (
str_llm_agent.function_calling_llm.model == "online_llm"
), "agent's function_calling_llm is incorrect"
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_before_kickoff_modification(): def test_before_kickoff_modification():
crew = InternalCrew() crew = InternalCrew()