This commit is contained in:
Lorenze Jay
2024-11-19 09:41:33 -08:00
parent b104404418
commit 70910dd7b4
2 changed files with 98 additions and 40 deletions

View File

@@ -27,8 +27,10 @@ class Knowledge(BaseModel):
for source in self.sources: for source in self.sources:
source.add() source.add()
except Exception as e: except Exception as e:
Logger.log( Logger(verbose=True).log(
"warning", f"Failed to add some sources during initialization: {e}" "warning",
f"Failed to init knowledge: {e}",
color="red",
) )
def query( def query(

View File

@@ -1,5 +1,6 @@
"""Test Knowledge creation and querying functionality.""" """Test Knowledge creation and querying functionality."""
import logging
from pathlib import Path from pathlib import Path
from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge import Knowledge
@@ -10,19 +11,32 @@ from crewai.knowledge.source.pdf_knowledge_source import PDFKnowledgeSource
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.knowledge.source.text_file_knowledge_source import TextFileKnowledgeSource from crewai.knowledge.source.text_file_knowledge_source import TextFileKnowledgeSource
import pytest
@pytest.fixture(autouse=True)
def reset_knowledge_storage():
"""Fixture to reset knowledge storage before each test."""
Knowledge().storage.reset()
yield
def test_single_short_string(): def test_single_short_string():
logging.basicConfig(level=logging.INFO)
# Create a knowledge base with a single short string # Create a knowledge base with a single short string
content = "Brandon's favorite color is blue and he likes Mexican food." content = "Brandon's favorite color is blue and he likes Mexican food."
string_source = StringKnowledgeSource(content=content) string_source = StringKnowledgeSource(
content=content, metadata={"preference": "personal"}
)
knowledge_base = Knowledge(sources=[string_source]) knowledge_base = Knowledge(sources=[string_source])
# Perform a query # Perform a query
query = "What is Brandon's favorite color?" query = "What is Brandon's favorite color?"
results = knowledge_base.query(query) results = knowledge_base.query(query)
# Assert that the results contain the expected information # # Assert that the results contain the expected information
assert any("blue" in result.lower() for result in results) assert any("blue" in result["context"].lower() for result in results)
def test_single_2k_character_string(): def test_single_2k_character_string():
@@ -49,7 +63,9 @@ def test_single_2k_character_string():
"Brandon's favorite sport is basketball, and he often plays with his friends on weekends. " "Brandon's favorite sport is basketball, and he often plays with his friends on weekends. "
"He is also a fan of the Golden State Warriors and enjoys watching their games. " "He is also a fan of the Golden State Warriors and enjoys watching their games. "
) )
string_source = StringKnowledgeSource(content=content) string_source = StringKnowledgeSource(
content=content, metadata={"preference": "personal"}
)
knowledge_base = Knowledge(sources=[string_source]) knowledge_base = Knowledge(sources=[string_source])
# Perform a query # Perform a query
@@ -57,7 +73,7 @@ def test_single_2k_character_string():
results = knowledge_base.query(query) results = knowledge_base.query(query)
# Assert that the results contain the expected information # Assert that the results contain the expected information
assert any("inception" in result.lower() for result in results) assert any("inception" in result["context"].lower() for result in results)
def test_multiple_short_strings(): def test_multiple_short_strings():
@@ -67,7 +83,10 @@ def test_multiple_short_strings():
"Brandon has a dog named Max.", "Brandon has a dog named Max.",
"Brandon enjoys painting landscapes.", "Brandon enjoys painting landscapes.",
] ]
string_sources = [StringKnowledgeSource(content=content) for content in contents] string_sources = [
StringKnowledgeSource(content=content, metadata={"preference": "personal"})
for content in contents
]
knowledge_base = Knowledge(sources=string_sources) knowledge_base = Knowledge(sources=string_sources)
# Perform a query # Perform a query
@@ -75,7 +94,7 @@ def test_multiple_short_strings():
results = knowledge_base.query(query) results = knowledge_base.query(query)
# Assert that the correct information is retrieved # Assert that the correct information is retrieved
assert any("max" in result.lower() for result in results) assert any("max" in result["context"].lower() for result in results)
def test_multiple_2k_character_strings(): def test_multiple_2k_character_strings():
@@ -128,7 +147,13 @@ def test_multiple_2k_character_strings():
) )
* 2, # Repeat to ensure it's 2k characters * 2, # Repeat to ensure it's 2k characters
] ]
string_sources = [StringKnowledgeSource(content=content) for content in contents] string_sources = [
StringKnowledgeSource(content=content, metadata={"preference": "personal"})
for content in contents
]
# Reset the knowledge storage for each test
# Knowledge().storage.reset()
knowledge_base = Knowledge(sources=string_sources) knowledge_base = Knowledge(sources=string_sources)
# Perform a query # Perform a query
@@ -137,7 +162,8 @@ def test_multiple_2k_character_strings():
# Assert that the correct information is retrieved # Assert that the correct information is retrieved
assert any( assert any(
"the hitchhiker's guide to the galaxy" in result.lower() for result in results "the hitchhiker's guide to the galaxy" in result["context"].lower()
for result in results
) )
@@ -148,7 +174,9 @@ def test_single_short_file(tmpdir):
with open(file_path, "w") as f: with open(file_path, "w") as f:
f.write(content) f.write(content)
file_source = TextFileKnowledgeSource(file_path=file_path) file_source = TextFileKnowledgeSource(
file_path=file_path, metadata={"preference": "personal"}
)
knowledge_base = Knowledge(sources=[file_source]) knowledge_base = Knowledge(sources=[file_source])
# Perform a query # Perform a query
@@ -156,7 +184,7 @@ def test_single_short_file(tmpdir):
results = knowledge_base.query(query) results = knowledge_base.query(query)
# Assert that the results contain the expected information # Assert that the results contain the expected information
assert any("basketball" in result.lower() for result in results) assert any("basketball" in result["context"].lower() for result in results)
def test_single_2k_character_file(tmpdir): def test_single_2k_character_file(tmpdir):
@@ -187,7 +215,9 @@ def test_single_2k_character_file(tmpdir):
with open(file_path, "w") as f: with open(file_path, "w") as f:
f.write(content) f.write(content)
file_source = TextFileKnowledgeSource(file_path=file_path) file_source = TextFileKnowledgeSource(
file_path=file_path, metadata={"preference": "personal"}
)
knowledge_base = Knowledge(sources=[file_source]) knowledge_base = Knowledge(sources=[file_source])
# Perform a query # Perform a query
@@ -195,32 +225,43 @@ def test_single_2k_character_file(tmpdir):
results = knowledge_base.query(query) results = knowledge_base.query(query)
# Assert that the results contain the expected information # Assert that the results contain the expected information
assert any("inception" in result.lower() for result in results) assert any("inception" in result["context"].lower() for result in results)
def test_multiple_short_files(tmpdir): def test_multiple_short_files(tmpdir):
# Create multiple short text files # Create multiple short text files
contents = [ contents = [
"Brandon lives in New York.", {
"Brandon works as a software engineer.", "content": "Brandon works as a software engineer.",
"Brandon enjoys cooking Italian food.", "metadata": {"category": "profession", "source": "occupation"},
},
{
"content": "Brandon lives in New York.",
"metadata": {"category": "city", "source": "personal"},
},
{
"content": "Brandon enjoys cooking Italian food.",
"metadata": {"category": "hobby", "source": "personal"},
},
] ]
file_paths = [] file_paths = []
for i, content in enumerate(contents): for i, item in enumerate(contents):
file_path = Path(tmpdir.join(f"file_{i}.txt")) file_path = Path(tmpdir.join(f"file_{i}.txt"))
with open(file_path, "w") as f: with open(file_path, "w") as f:
f.write(content) f.write(item["content"])
file_paths.append(file_path) file_paths.append((file_path, item["metadata"]))
file_sources = [TextFileKnowledgeSource(file_path=path) for path in file_paths] file_sources = [
TextFileKnowledgeSource(file_path=path, metadata=metadata)
for path, metadata in file_paths
]
knowledge_base = Knowledge(sources=file_sources) knowledge_base = Knowledge(sources=file_sources)
# Perform a query # Perform a query
query = "Where does Brandon live?" query = "What city does he reside in?"
results = knowledge_base.query(query) results = knowledge_base.query(query)
# Assert that the correct information is retrieved # Assert that the correct information is retrieved
assert any("new york" in result.lower() for result in results) assert any("new york" in result["context"].lower() for result in results)
def test_multiple_2k_character_files(tmpdir): def test_multiple_2k_character_files(tmpdir):
@@ -280,7 +321,10 @@ def test_multiple_2k_character_files(tmpdir):
f.write(content) f.write(content)
file_paths.append(file_path) file_paths.append(file_path)
file_sources = [TextFileKnowledgeSource(file_path=path) for path in file_paths] file_sources = [
TextFileKnowledgeSource(file_path=path, metadata={"preference": "personal"})
for path in file_paths
]
knowledge_base = Knowledge(sources=file_sources) knowledge_base = Knowledge(sources=file_sources)
# Perform a query # Perform a query
@@ -289,7 +333,8 @@ def test_multiple_2k_character_files(tmpdir):
# Assert that the correct information is retrieved # Assert that the correct information is retrieved
assert any( assert any(
"the hitchhiker's guide to the galaxy" in result.lower() for result in results "the hitchhiker's guide to the galaxy" in result["context"].lower()
for result in results
) )
@@ -300,7 +345,8 @@ def test_hybrid_string_and_files(tmpdir):
"Brandon visited Paris last summer.", "Brandon visited Paris last summer.",
] ]
string_sources = [ string_sources = [
StringKnowledgeSource(content=content) for content in string_contents StringKnowledgeSource(content=content, metadata={"preference": "personal"})
for content in string_contents
] ]
# Create file sources # Create file sources
@@ -315,7 +361,10 @@ def test_hybrid_string_and_files(tmpdir):
f.write(content) f.write(content)
file_paths.append(file_path) file_paths.append(file_path)
file_sources = [TextFileKnowledgeSource(file_path=path) for path in file_paths] file_sources = [
TextFileKnowledgeSource(file_path=path, metadata={"preference": "personal"})
for path in file_paths
]
# Combine string and file sources # Combine string and file sources
knowledge_base = Knowledge(sources=string_sources + file_sources) knowledge_base = Knowledge(sources=string_sources + file_sources)
@@ -325,7 +374,7 @@ def test_hybrid_string_and_files(tmpdir):
results = knowledge_base.query(query) results = knowledge_base.query(query)
# Assert that the correct information is retrieved # Assert that the correct information is retrieved
assert any("the alchemist" in result.lower() for result in results) assert any("the alchemist" in result["context"].lower() for result in results)
def test_pdf_knowledge_source(): def test_pdf_knowledge_source():
@@ -335,17 +384,18 @@ def test_pdf_knowledge_source():
pdf_path = current_dir / "crewai_quickstart.pdf" pdf_path = current_dir / "crewai_quickstart.pdf"
# Create a PDFKnowledgeSource # Create a PDFKnowledgeSource
pdf_source = PDFKnowledgeSource(file_path=pdf_path) pdf_source = PDFKnowledgeSource(
file_path=pdf_path, metadata={"preference": "personal"}
)
knowledge_base = Knowledge(sources=[pdf_source]) knowledge_base = Knowledge(sources=[pdf_source])
# Perform a query # Perform a query
query = "How do you create a crew?" query = "How do you create a crew?"
results = knowledge_base.query(query) results = knowledge_base.query(query)
print("Results from querying PDFKnowledgeSource:", results)
# Assert that the correct information is retrieved # Assert that the correct information is retrieved
assert any( assert any(
"crewai create crew latest-ai-development" in result.lower() "crewai create crew latest-ai-development" in result["context"].lower()
for result in results for result in results
) )
@@ -366,7 +416,9 @@ def test_csv_knowledge_source(tmpdir):
f.write(",".join(row) + "\n") f.write(",".join(row) + "\n")
# Create a CSVKnowledgeSource # Create a CSVKnowledgeSource
csv_source = CSVKnowledgeSource(file_path=csv_path) csv_source = CSVKnowledgeSource(
file_path=csv_path, metadata={"preference": "personal"}
)
knowledge_base = Knowledge(sources=[csv_source]) knowledge_base = Knowledge(sources=[csv_source])
# Perform a query # Perform a query
@@ -374,7 +426,7 @@ def test_csv_knowledge_source(tmpdir):
results = knowledge_base.query(query) results = knowledge_base.query(query)
# Assert that the correct information is retrieved # Assert that the correct information is retrieved
assert any("30" in result for result in results) assert any("30" in result["context"] for result in results)
def test_json_knowledge_source(tmpdir): def test_json_knowledge_source(tmpdir):
@@ -395,15 +447,17 @@ def test_json_knowledge_source(tmpdir):
json.dump(json_data, f) json.dump(json_data, f)
# Create a JSONKnowledgeSource # Create a JSONKnowledgeSource
json_source = JSONKnowledgeSource(file_path=json_path) json_source = JSONKnowledgeSource(
file_path=json_path, metadata={"preference": "personal"}
)
knowledge_base = Knowledge(sources=[json_source]) knowledge_base = Knowledge(sources=[json_source])
# Perform a query # Perform a query
query = "Where does Brandon live?" query = "Where does Alice reside?"
results = knowledge_base.query(query) results = knowledge_base.query(query)
# Assert that the correct information is retrieved # Assert that the correct information is retrieved
assert any("New York" in result for result in results) assert any("los angeles" in result["context"].lower() for result in results)
def test_excel_knowledge_source(tmpdir): def test_excel_knowledge_source(tmpdir):
@@ -422,7 +476,9 @@ def test_excel_knowledge_source(tmpdir):
df.to_excel(excel_path, index=False) df.to_excel(excel_path, index=False)
# Create an ExcelKnowledgeSource # Create an ExcelKnowledgeSource
excel_source = ExcelKnowledgeSource(file_path=excel_path) excel_source = ExcelKnowledgeSource(
file_path=excel_path, metadata={"preference": "personal"}
)
knowledge_base = Knowledge(sources=[excel_source]) knowledge_base = Knowledge(sources=[excel_source])
# Perform a query # Perform a query
@@ -430,4 +486,4 @@ def test_excel_knowledge_source(tmpdir):
results = knowledge_base.query(query) results = knowledge_base.query(query)
# Assert that the correct information is retrieved # Assert that the correct information is retrieved
assert any("30" in result for result in results) assert any("30" in result["context"] for result in results)