added tool for docling support

This commit is contained in:
Lorenze Jay
2024-12-15 22:15:49 -08:00
parent 6d7c1b0743
commit 04cb9afae5
2 changed files with 112 additions and 9 deletions

View File

@@ -11,6 +11,7 @@ from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource
from crewai.knowledge.source.pdf_knowledge_source import PDFKnowledgeSource
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.knowledge.source.text_file_knowledge_source import TextFileKnowledgeSource
from crewai.knowledge.source.docling_source import DoclingSource
@pytest.fixture(autouse=True)
@@ -200,7 +201,7 @@ def test_single_short_file(mock_vector_db, tmpdir):
f.write(content)
file_source = TextFileKnowledgeSource(
file_path=file_path, metadata={"preference": "personal"}
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
@@ -242,7 +243,7 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
f.write(content)
file_source = TextFileKnowledgeSource(
file_path=file_path, metadata={"preference": "personal"}
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
@@ -279,7 +280,7 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
file_paths.append((file_path, item["metadata"]))
file_sources = [
TextFileKnowledgeSource(file_path=path, metadata=metadata)
TextFileKnowledgeSource(file_paths=[path], metadata=metadata)
for path, metadata in file_paths
]
mock_vector_db.sources = file_sources
@@ -352,7 +353,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
file_paths.append(file_path)
file_sources = [
TextFileKnowledgeSource(file_path=path, metadata={"preference": "personal"})
TextFileKnowledgeSource(file_paths=[path], metadata={"preference": "personal"})
for path in file_paths
]
mock_vector_db.sources = file_sources
@@ -399,7 +400,7 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
file_paths.append(file_path)
file_sources = [
TextFileKnowledgeSource(file_path=path, metadata={"preference": "personal"})
TextFileKnowledgeSource(file_paths=[path], metadata={"preference": "personal"})
for path in file_paths
]
@@ -424,7 +425,7 @@ def test_pdf_knowledge_source(mock_vector_db):
# Create a PDFKnowledgeSource
pdf_source = PDFKnowledgeSource(
file_path=pdf_path, metadata={"preference": "personal"}
file_paths=[pdf_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [pdf_source]
mock_vector_db.query.return_value = [
@@ -461,7 +462,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
# Create a CSVKnowledgeSource
csv_source = CSVKnowledgeSource(
file_path=csv_path, metadata={"preference": "personal"}
file_paths=[csv_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [csv_source]
mock_vector_db.query.return_value = [
@@ -496,7 +497,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
# Create a JSONKnowledgeSource
json_source = JSONKnowledgeSource(
file_path=json_path, metadata={"preference": "personal"}
file_paths=[json_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [json_source]
mock_vector_db.query.return_value = [
@@ -529,7 +530,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
# Create an ExcelKnowledgeSource
excel_source = ExcelKnowledgeSource(
file_path=excel_path, metadata={"preference": "personal"}
file_paths=[excel_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [excel_source]
mock_vector_db.query.return_value = [
@@ -543,3 +544,23 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
# Assert that the correct information is retrieved
assert any("30" in result["context"] for result in results)
mock_vector_db.query.assert_called_once()
def test_docling_source(mock_vector_db):
docling_source = DoclingSource(
file_paths=[
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
],
)
mock_vector_db.sources = [docling_source]
mock_vector_db.query.return_value = [
{
"context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
"score": 0.9,
}
]
# Perform a query
query = "What is reward hacking?"
results = mock_vector_db.query(query)
assert any("reward hacking" in result["context"].lower() for result in results)
mock_vector_db.query.assert_called_once()