diff --git a/tests/agent_test.py b/tests/agent_test.py index 796e651db..83e530427 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -1618,57 +1618,71 @@ def test_agent_with_knowledge_sources(): def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold(): content = "Brandon's favorite color is red and he likes Mexican food." string_source = StringKnowledgeSource(content=content) - knowledge_config = KnowledgeConfig(limit=10, score_threshold=0.5) - with patch.object(Knowledge, "query") as mock_knowledge_query: - agent = Agent( - role="Information Agent", - goal="Provide information based on knowledge sources", - backstory="You have access to specific knowledge sources.", - llm=LLM(model="gpt-4o-mini"), - knowledge_sources=[string_source], - knowledge_config=knowledge_config, - ) - task = Task( - description="What is Brandon's favorite color?", - expected_output="Brandon's favorite color.", - agent=agent, - ) - crew = Crew(agents=[agent], tasks=[task]) - crew.kickoff() + knowledge_config = KnowledgeConfig(results_limit=10, score_threshold=0.5) + with patch( + "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" + ) as MockKnowledge: + mock_knowledge_instance = MockKnowledge.return_value + mock_knowledge_instance.sources = [string_source] + mock_knowledge_instance.query.return_value = [{"content": content}] + with patch.object(Knowledge, "query") as mock_knowledge_query: + agent = Agent( + role="Information Agent", + goal="Provide information based on knowledge sources", + backstory="You have access to specific knowledge sources.", + llm=LLM(model="gpt-4o-mini"), + knowledge_sources=[string_source], + knowledge_config=knowledge_config, + ) + task = Task( + description="What is Brandon's favorite color?", + expected_output="Brandon's favorite color.", + agent=agent, + ) + crew = Crew(agents=[agent], tasks=[task]) + crew.kickoff() - assert agent.knowledge is not None - mock_knowledge_query.assert_called_once_with( - [task.prompt()], - **knowledge_config.model_dump(), - ) + assert agent.knowledge is not None + mock_knowledge_query.assert_called_once_with( + [task.prompt()], + **knowledge_config.model_dump(), + ) def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_default(): content = "Brandon's favorite color is red and he likes Mexican food." string_source = StringKnowledgeSource(content=content) knowledge_config = KnowledgeConfig() - with patch.object(Knowledge, "query") as mock_knowledge_query: - agent = Agent( - role="Information Agent", - goal="Provide information based on knowledge sources", - backstory="You have access to specific knowledge sources.", - llm=LLM(model="gpt-4o-mini"), - knowledge_sources=[string_source], - knowledge_config=knowledge_config, - ) - task = Task( - description="What is Brandon's favorite color?", - expected_output="Brandon's favorite color.", - agent=agent, - ) - crew = Crew(agents=[agent], tasks=[task]) - crew.kickoff() + with patch( + "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" + ) as MockKnowledge: + mock_knowledge_instance = MockKnowledge.return_value + mock_knowledge_instance.sources = [string_source] + mock_knowledge_instance.query.return_value = [{"content": content}] + with patch.object(Knowledge, "query") as mock_knowledge_query: + string_source = StringKnowledgeSource(content=content) + knowledge_config = KnowledgeConfig() + agent = Agent( + role="Information Agent", + goal="Provide information based on knowledge sources", + backstory="You have access to specific knowledge sources.", + llm=LLM(model="gpt-4o-mini"), + knowledge_sources=[string_source], + knowledge_config=knowledge_config, + ) + task = Task( + description="What is Brandon's favorite color?", + expected_output="Brandon's favorite color.", + agent=agent, + ) + crew = Crew(agents=[agent], tasks=[task]) + crew.kickoff() - assert agent.knowledge is not None - mock_knowledge_query.assert_called_once_with( - [task.prompt()], - **knowledge_config.model_dump(), - ) + assert agent.knowledge is not None + mock_knowledge_query.assert_called_once_with( + [task.prompt()], + **knowledge_config.model_dump(), + ) @pytest.mark.vcr(filter_headers=["authorization"])