mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-20 06:18:29 +00:00
Compare commits
26 Commits
bugfix/add
...
feat/docli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bebf8e9eb1 | ||
|
|
abdc7133d5 | ||
|
|
9dda698b66 | ||
|
|
6faa0317b2 | ||
|
|
bc230b4edf | ||
|
|
f380f8ee23 | ||
|
|
c3d31deff6 | ||
|
|
aedaf01d19 | ||
|
|
436a458072 | ||
|
|
7885c5f906 | ||
|
|
ef7a101631 | ||
|
|
e14a49f82c | ||
|
|
0921f71fd2 | ||
|
|
10c04d54a9 | ||
|
|
356eb07d5f | ||
|
|
c2ed1f2355 | ||
|
|
76c640b985 | ||
|
|
054bc266b9 | ||
|
|
f1c9caa8ec | ||
|
|
610ea40c2d | ||
|
|
b14f6ffa59 | ||
|
|
56172ecf1d | ||
|
|
ee74ad0d6d | ||
|
|
a67ec7e37a | ||
|
|
625c21da5b | ||
|
|
04cb9afae5 |
@@ -79,6 +79,55 @@ crew = Crew(
|
|||||||
result = crew.kickoff(inputs={"question": "What city does John live in and how old is he?"})
|
result = crew.kickoff(inputs={"question": "What city does John live in and how old is he?"})
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Here's another example with the `CrewDoclingSource`
|
||||||
|
```python Code
|
||||||
|
from crewai import LLM, Agent, Crew, Process, Task
|
||||||
|
from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
|
||||||
|
|
||||||
|
# Create a knowledge source
|
||||||
|
content_source = CrewDoclingSource(
|
||||||
|
file_paths=[
|
||||||
|
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking",
|
||||||
|
"https://lilianweng.github.io/posts/2024-07-07-hallucination",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create an LLM with a temperature of 0 to ensure deterministic outputs
|
||||||
|
llm = LLM(model="gpt-4o-mini", temperature=0)
|
||||||
|
|
||||||
|
# Create an agent with the knowledge store
|
||||||
|
agent = Agent(
|
||||||
|
role="About papers",
|
||||||
|
goal="You know everything about the papers.",
|
||||||
|
backstory="""You are a master at understanding papers and their content.""",
|
||||||
|
verbose=True,
|
||||||
|
allow_delegation=False,
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
task = Task(
|
||||||
|
description="Answer the following questions about the papers: {question}",
|
||||||
|
expected_output="An answer to the question.",
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = Crew(
|
||||||
|
agents=[agent],
|
||||||
|
tasks=[task],
|
||||||
|
verbose=True,
|
||||||
|
process=Process.sequential,
|
||||||
|
knowledge_sources=[
|
||||||
|
content_source
|
||||||
|
], # Enable knowledge by adding the sources here. You can also add more sources to the sources list.
|
||||||
|
)
|
||||||
|
|
||||||
|
result = crew.kickoff(
|
||||||
|
inputs={
|
||||||
|
"question": "What is the reward hacking paper about? Be sure to provide sources."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
## Knowledge Configuration
|
## Knowledge Configuration
|
||||||
|
|
||||||
### Chunking Configuration
|
### Chunking Configuration
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ openpyxl = [
|
|||||||
"openpyxl>=3.1.5",
|
"openpyxl>=3.1.5",
|
||||||
]
|
]
|
||||||
mem0 = ["mem0ai>=0.1.29"]
|
mem0 = ["mem0ai>=0.1.29"]
|
||||||
|
docling = [
|
||||||
|
"docling>=2.12.0",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
dev-dependencies = [
|
dev-dependencies = [
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||||
@@ -14,17 +14,28 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
"""Base class for knowledge sources that load content from files."""
|
"""Base class for knowledge sources that load content from files."""
|
||||||
|
|
||||||
_logger: Logger = Logger(verbose=True)
|
_logger: Logger = Logger(verbose=True)
|
||||||
file_path: Union[Path, List[Path], str, List[str]] = Field(
|
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||||
..., description="The path to the file"
|
default=None,
|
||||||
|
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||||
|
)
|
||||||
|
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||||
|
default_factory=list, description="The path to the file"
|
||||||
)
|
)
|
||||||
content: Dict[Path, str] = Field(init=False, default_factory=dict)
|
content: Dict[Path, str] = Field(init=False, default_factory=dict)
|
||||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
||||||
safe_file_paths: List[Path] = Field(default_factory=list)
|
safe_file_paths: List[Path] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@field_validator("file_path", "file_paths", mode="before")
|
||||||
|
def validate_file_path(cls, v, values):
|
||||||
|
"""Validate that at least one of file_path or file_paths is provided."""
|
||||||
|
if v is None and ("file_path" not in values or values.get("file_path") is None):
|
||||||
|
raise ValueError("Either file_path or file_paths must be provided")
|
||||||
|
return v
|
||||||
|
|
||||||
def model_post_init(self, _):
|
def model_post_init(self, _):
|
||||||
"""Post-initialization method to load content."""
|
"""Post-initialization method to load content."""
|
||||||
self.safe_file_paths = self._process_file_paths()
|
self.safe_file_paths = self._process_file_paths()
|
||||||
self.validate_paths()
|
self.validate_content()
|
||||||
self.content = self.load_content()
|
self.content = self.load_content()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -32,7 +43,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def validate_paths(self):
|
def validate_content(self):
|
||||||
"""Validate the paths."""
|
"""Validate the paths."""
|
||||||
for path in self.safe_file_paths:
|
for path in self.safe_file_paths:
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
@@ -59,13 +70,29 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
|
|
||||||
def _process_file_paths(self) -> List[Path]:
|
def _process_file_paths(self) -> List[Path]:
|
||||||
"""Convert file_path to a list of Path objects."""
|
"""Convert file_path to a list of Path objects."""
|
||||||
paths = (
|
|
||||||
[self.file_path]
|
|
||||||
if isinstance(self.file_path, (str, Path))
|
|
||||||
else self.file_path
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(paths, list):
|
# Check if old file_path is being used
|
||||||
raise ValueError("file_path must be a Path, str, or a list of these types")
|
if hasattr(self, "file_path") and self.file_path is not None:
|
||||||
|
self._logger.log(
|
||||||
|
"warning",
|
||||||
|
"The 'file_path' attribute is deprecated and will be removed in a future version. Please use 'file_paths' instead.",
|
||||||
|
color="yellow",
|
||||||
|
)
|
||||||
|
paths = (
|
||||||
|
[self.file_path]
|
||||||
|
if isinstance(self.file_path, (str, Path))
|
||||||
|
else self.file_path
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.file_paths is None:
|
||||||
|
raise ValueError("Your source must be provided with a file_paths: []")
|
||||||
|
elif isinstance(self.file_paths, list) and len(self.file_paths) == 0:
|
||||||
|
raise ValueError("Empty file_paths are not allowed")
|
||||||
|
else:
|
||||||
|
paths = (
|
||||||
|
[self.file_paths]
|
||||||
|
if isinstance(self.file_paths, (str, Path))
|
||||||
|
else self.file_paths
|
||||||
|
)
|
||||||
|
|
||||||
return [self.convert_to_path(path) for path in paths]
|
return [self.convert_to_path(path) for path in paths]
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
|||||||
collection_name: Optional[str] = Field(default=None)
|
collection_name: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_content(self) -> Dict[Any, str]:
|
def validate_content(self) -> Any:
|
||||||
"""Load and preprocess content from the source."""
|
"""Load and preprocess content from the source."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
120
src/crewai/knowledge/source/crew_docling_source.py
Normal file
120
src/crewai/knowledge/source/crew_docling_source.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterator, List, Optional, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from docling.datamodel.base_models import InputFormat
|
||||||
|
from docling.document_converter import DocumentConverter
|
||||||
|
from docling.exceptions import ConversionError
|
||||||
|
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
|
||||||
|
from docling_core.types.doc.document import DoclingDocument
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
|
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||||
|
from crewai.utilities.logger import Logger
|
||||||
|
|
||||||
|
|
||||||
|
class CrewDoclingSource(BaseKnowledgeSource):
|
||||||
|
"""Default Source class for converting documents to markdown or json
|
||||||
|
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_logger: Logger = Logger(verbose=True)
|
||||||
|
|
||||||
|
file_path: Optional[List[Union[Path, str]]] = Field(default=None)
|
||||||
|
file_paths: List[Union[Path, str]] = Field(default_factory=list)
|
||||||
|
chunks: List[str] = Field(default_factory=list)
|
||||||
|
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
|
||||||
|
content: List[DoclingDocument] = Field(default_factory=list)
|
||||||
|
document_converter: DocumentConverter = Field(
|
||||||
|
default_factory=lambda: DocumentConverter(
|
||||||
|
allowed_formats=[
|
||||||
|
InputFormat.MD,
|
||||||
|
InputFormat.ASCIIDOC,
|
||||||
|
InputFormat.PDF,
|
||||||
|
InputFormat.DOCX,
|
||||||
|
InputFormat.HTML,
|
||||||
|
InputFormat.IMAGE,
|
||||||
|
InputFormat.XLSX,
|
||||||
|
InputFormat.PPTX,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def model_post_init(self, _) -> None:
|
||||||
|
if self.file_path:
|
||||||
|
self._logger.log(
|
||||||
|
"warning",
|
||||||
|
"The 'file_path' attribute is deprecated and will be removed in a future version. Please use 'file_paths' instead.",
|
||||||
|
color="yellow",
|
||||||
|
)
|
||||||
|
self.file_paths = self.file_path
|
||||||
|
self.safe_file_paths = self.validate_content()
|
||||||
|
self.content = self._load_content()
|
||||||
|
|
||||||
|
def _load_content(self) -> List[DoclingDocument]:
|
||||||
|
try:
|
||||||
|
return self._convert_source_to_docling_documents()
|
||||||
|
except ConversionError as e:
|
||||||
|
self._logger.log(
|
||||||
|
"error",
|
||||||
|
f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
|
||||||
|
"red",
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.log("error", f"Error loading content: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def add(self) -> None:
|
||||||
|
if self.content is None:
|
||||||
|
return
|
||||||
|
for doc in self.content:
|
||||||
|
new_chunks_iterable = self._chunk_doc(doc)
|
||||||
|
self.chunks.extend(list(new_chunks_iterable))
|
||||||
|
self._save_documents()
|
||||||
|
|
||||||
|
def _convert_source_to_docling_documents(self) -> List[DoclingDocument]:
|
||||||
|
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
||||||
|
return [result.document for result in conv_results_iter]
|
||||||
|
|
||||||
|
def _chunk_doc(self, doc: DoclingDocument) -> Iterator[str]:
|
||||||
|
chunker = HierarchicalChunker()
|
||||||
|
for chunk in chunker.chunk(doc):
|
||||||
|
yield chunk.text
|
||||||
|
|
||||||
|
def validate_content(self) -> List[Union[Path, str]]:
|
||||||
|
processed_paths: List[Union[Path, str]] = []
|
||||||
|
for path in self.file_paths:
|
||||||
|
if isinstance(path, str):
|
||||||
|
if path.startswith(("http://", "https://")):
|
||||||
|
try:
|
||||||
|
if self._validate_url(path):
|
||||||
|
processed_paths.append(path)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid URL format: {path}")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
|
||||||
|
else:
|
||||||
|
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
|
||||||
|
if local_path.exists():
|
||||||
|
processed_paths.append(local_path)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"File not found: {local_path}")
|
||||||
|
else:
|
||||||
|
# this is an instance of Path
|
||||||
|
processed_paths.append(path)
|
||||||
|
return processed_paths
|
||||||
|
|
||||||
|
def _validate_url(self, url: str) -> bool:
|
||||||
|
try:
|
||||||
|
result = urlparse(url)
|
||||||
|
return all(
|
||||||
|
[
|
||||||
|
result.scheme in ("http", "https"),
|
||||||
|
result.netloc,
|
||||||
|
len(result.netloc.split(".")) >= 2, # Ensure domain has TLD
|
||||||
|
]
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
@@ -13,9 +13,9 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
|||||||
|
|
||||||
def model_post_init(self, _):
|
def model_post_init(self, _):
|
||||||
"""Post-initialization method to validate content."""
|
"""Post-initialization method to validate content."""
|
||||||
self.load_content()
|
self.validate_content()
|
||||||
|
|
||||||
def load_content(self):
|
def validate_content(self):
|
||||||
"""Validate string content."""
|
"""Validate string content."""
|
||||||
if not isinstance(self.content, str):
|
if not isinstance(self.content, str):
|
||||||
raise ValueError("StringKnowledgeSource only accepts string content")
|
raise ValueError("StringKnowledgeSource only accepts string content")
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
"""Test Knowledge creation and querying functionality."""
|
"""Test Knowledge creation and querying functionality."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List, Union
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
|
||||||
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
|
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
|
||||||
from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource
|
from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource
|
||||||
from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource
|
from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource
|
||||||
@@ -200,7 +202,7 @@ def test_single_short_file(mock_vector_db, tmpdir):
|
|||||||
f.write(content)
|
f.write(content)
|
||||||
|
|
||||||
file_source = TextFileKnowledgeSource(
|
file_source = TextFileKnowledgeSource(
|
||||||
file_path=file_path, metadata={"preference": "personal"}
|
file_paths=[file_path], metadata={"preference": "personal"}
|
||||||
)
|
)
|
||||||
mock_vector_db.sources = [file_source]
|
mock_vector_db.sources = [file_source]
|
||||||
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
||||||
@@ -242,7 +244,7 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
|
|||||||
f.write(content)
|
f.write(content)
|
||||||
|
|
||||||
file_source = TextFileKnowledgeSource(
|
file_source = TextFileKnowledgeSource(
|
||||||
file_path=file_path, metadata={"preference": "personal"}
|
file_paths=[file_path], metadata={"preference": "personal"}
|
||||||
)
|
)
|
||||||
mock_vector_db.sources = [file_source]
|
mock_vector_db.sources = [file_source]
|
||||||
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
||||||
@@ -279,7 +281,7 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
|
|||||||
file_paths.append((file_path, item["metadata"]))
|
file_paths.append((file_path, item["metadata"]))
|
||||||
|
|
||||||
file_sources = [
|
file_sources = [
|
||||||
TextFileKnowledgeSource(file_path=path, metadata=metadata)
|
TextFileKnowledgeSource(file_paths=[path], metadata=metadata)
|
||||||
for path, metadata in file_paths
|
for path, metadata in file_paths
|
||||||
]
|
]
|
||||||
mock_vector_db.sources = file_sources
|
mock_vector_db.sources = file_sources
|
||||||
@@ -352,7 +354,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
|
|||||||
file_paths.append(file_path)
|
file_paths.append(file_path)
|
||||||
|
|
||||||
file_sources = [
|
file_sources = [
|
||||||
TextFileKnowledgeSource(file_path=path, metadata={"preference": "personal"})
|
TextFileKnowledgeSource(file_paths=[path], metadata={"preference": "personal"})
|
||||||
for path in file_paths
|
for path in file_paths
|
||||||
]
|
]
|
||||||
mock_vector_db.sources = file_sources
|
mock_vector_db.sources = file_sources
|
||||||
@@ -399,7 +401,7 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
|
|||||||
file_paths.append(file_path)
|
file_paths.append(file_path)
|
||||||
|
|
||||||
file_sources = [
|
file_sources = [
|
||||||
TextFileKnowledgeSource(file_path=path, metadata={"preference": "personal"})
|
TextFileKnowledgeSource(file_paths=[path], metadata={"preference": "personal"})
|
||||||
for path in file_paths
|
for path in file_paths
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -424,7 +426,7 @@ def test_pdf_knowledge_source(mock_vector_db):
|
|||||||
|
|
||||||
# Create a PDFKnowledgeSource
|
# Create a PDFKnowledgeSource
|
||||||
pdf_source = PDFKnowledgeSource(
|
pdf_source = PDFKnowledgeSource(
|
||||||
file_path=pdf_path, metadata={"preference": "personal"}
|
file_paths=[pdf_path], metadata={"preference": "personal"}
|
||||||
)
|
)
|
||||||
mock_vector_db.sources = [pdf_source]
|
mock_vector_db.sources = [pdf_source]
|
||||||
mock_vector_db.query.return_value = [
|
mock_vector_db.query.return_value = [
|
||||||
@@ -461,7 +463,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
|
|||||||
|
|
||||||
# Create a CSVKnowledgeSource
|
# Create a CSVKnowledgeSource
|
||||||
csv_source = CSVKnowledgeSource(
|
csv_source = CSVKnowledgeSource(
|
||||||
file_path=csv_path, metadata={"preference": "personal"}
|
file_paths=[csv_path], metadata={"preference": "personal"}
|
||||||
)
|
)
|
||||||
mock_vector_db.sources = [csv_source]
|
mock_vector_db.sources = [csv_source]
|
||||||
mock_vector_db.query.return_value = [
|
mock_vector_db.query.return_value = [
|
||||||
@@ -496,7 +498,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
|
|||||||
|
|
||||||
# Create a JSONKnowledgeSource
|
# Create a JSONKnowledgeSource
|
||||||
json_source = JSONKnowledgeSource(
|
json_source = JSONKnowledgeSource(
|
||||||
file_path=json_path, metadata={"preference": "personal"}
|
file_paths=[json_path], metadata={"preference": "personal"}
|
||||||
)
|
)
|
||||||
mock_vector_db.sources = [json_source]
|
mock_vector_db.sources = [json_source]
|
||||||
mock_vector_db.query.return_value = [
|
mock_vector_db.query.return_value = [
|
||||||
@@ -529,7 +531,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
|
|||||||
|
|
||||||
# Create an ExcelKnowledgeSource
|
# Create an ExcelKnowledgeSource
|
||||||
excel_source = ExcelKnowledgeSource(
|
excel_source = ExcelKnowledgeSource(
|
||||||
file_path=excel_path, metadata={"preference": "personal"}
|
file_paths=[excel_path], metadata={"preference": "personal"}
|
||||||
)
|
)
|
||||||
mock_vector_db.sources = [excel_source]
|
mock_vector_db.sources = [excel_source]
|
||||||
mock_vector_db.query.return_value = [
|
mock_vector_db.query.return_value = [
|
||||||
@@ -543,3 +545,42 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
|
|||||||
# Assert that the correct information is retrieved
|
# Assert that the correct information is retrieved
|
||||||
assert any("30" in result["context"] for result in results)
|
assert any("30" in result["context"] for result in results)
|
||||||
mock_vector_db.query.assert_called_once()
|
mock_vector_db.query.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_docling_source(mock_vector_db):
|
||||||
|
docling_source = CrewDoclingSource(
|
||||||
|
file_paths=[
|
||||||
|
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
mock_vector_db.sources = [docling_source]
|
||||||
|
mock_vector_db.query.return_value = [
|
||||||
|
{
|
||||||
|
"context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
|
||||||
|
"score": 0.9,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
# Perform a query
|
||||||
|
query = "What is reward hacking?"
|
||||||
|
results = mock_vector_db.query(query)
|
||||||
|
assert any("reward hacking" in result["context"].lower() for result in results)
|
||||||
|
mock_vector_db.query.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_docling_sources():
|
||||||
|
urls: List[Union[Path, str]] = [
|
||||||
|
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
|
||||||
|
"https://lilianweng.github.io/posts/2024-07-07-hallucination/",
|
||||||
|
]
|
||||||
|
docling_source = CrewDoclingSource(file_paths=urls)
|
||||||
|
|
||||||
|
assert docling_source.file_paths == urls
|
||||||
|
assert docling_source.content is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_docling_source_with_local_file():
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
pdf_path = current_dir / "crewai_quickstart.pdf"
|
||||||
|
docling_source = CrewDoclingSource(file_paths=[pdf_path])
|
||||||
|
assert docling_source.file_paths == [pdf_path]
|
||||||
|
assert docling_source.content is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user