diff --git a/src/crewai/knowledge/source/string_knowledge_source.py b/src/crewai/knowledge/source/string_knowledge_source.py index f8905407b..94caeeb58 100644 --- a/src/crewai/knowledge/source/string_knowledge_source.py +++ b/src/crewai/knowledge/source/string_knowledge_source.py @@ -3,35 +3,73 @@ from typing import List, Optional from pydantic import Field from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource +from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage +from crewai.utilities.logger import Logger class StringKnowledgeSource(BaseKnowledgeSource): """A knowledge source that stores and queries plain text content using embeddings.""" + _logger: Logger = Logger(verbose=True) content: str = Field(...) collection_name: Optional[str] = Field(default=None) - def model_post_init(self, _): - """Post-initialization method to validate content and initialize storage.""" - self.validate_content() - if self.storage is None: - from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage - self.storage = KnowledgeStorage(collection_name=self.collection_name) - self.storage.initialize_knowledge_storage() + def model_post_init(self, _) -> None: + """Post-initialization method to validate content and initialize storage. + + This method is called after the model is initialized to perform content validation + and set up the knowledge storage system. It ensures that: + 1. The content is a valid string + 2. The storage system is properly initialized + + Raises: + ValueError: If content validation fails or storage initialization fails + """ + try: + self.validate_content() + if self.storage is None: + self.storage = KnowledgeStorage(collection_name=self.collection_name) + self.storage.initialize_knowledge_storage() + except Exception as e: + error_msg = f"Failed to initialize knowledge storage: {str(e)}" + self._logger.log("error", error_msg, "red") + raise ValueError(error_msg) - def validate_content(self): - """Validate string content.""" - if not isinstance(self.content, str): - raise ValueError("StringKnowledgeSource only accepts string content") + def validate_content(self) -> None: + """Validate that the content is a valid string. + + Raises: + ValueError: If content is not a string or is empty + """ + if not isinstance(self.content, str) or not self.content.strip(): + error_msg = "StringKnowledgeSource only accepts string content" + self._logger.log("error", error_msg, "red") + raise ValueError(error_msg) def add(self) -> None: - """Add string content to the knowledge source, chunk it, compute embeddings, and save them.""" + """Add string content to the knowledge source, chunk it, compute embeddings, and save them. + + This method processes the content by: + 1. Chunking the text into smaller pieces + 2. Adding the chunks to the source + 3. Computing embeddings and saving them + + Raises: + ValueError: If storage is not initialized when trying to save documents + """ new_chunks = self._chunk_text(self.content) self.chunks.extend(new_chunks) self._save_documents() def _chunk_text(self, text: str) -> List[str]: - """Utility method to split text into chunks.""" + """Split text into chunks based on chunk_size and chunk_overlap. + + Args: + text: The text to split into chunks + + Returns: + List[str]: List of text chunks + """ return [ text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size - self.chunk_overlap) diff --git a/tests/knowledge/knowledge_test.py b/tests/knowledge/knowledge_test.py index 696830875..0ef8bb46a 100644 --- a/tests/knowledge/knowledge_test.py +++ b/tests/knowledge/knowledge_test.py @@ -5,6 +5,7 @@ from typing import List, Union from unittest.mock import patch import pytest +from pydantic import ValidationError from crewai.knowledge.source.crew_docling_source import CrewDoclingSource from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource @@ -37,26 +38,40 @@ def reset_knowledge_storage(mock_vector_db): yield -def test_string_knowledge_source(mock_vector_db): - """Test StringKnowledgeSource with simple text content.""" - content = "Users name is John. He is 30 years old and lives in San Francisco." - string_source = StringKnowledgeSource(content=content) - mock_vector_db.sources = [string_source] - mock_vector_db.query.return_value = [{"context": content, "score": 0.9}] +class TestStringKnowledgeSource: + def test_initialization(self, mock_vector_db): + """Test basic initialization of StringKnowledgeSource.""" + content = "Users name is John. He is 30 years old and lives in San Francisco." + string_source = StringKnowledgeSource(content=content) + assert string_source.content == content + assert string_source.storage is not None - # Test initialization - assert string_source.content == content - - # Test adding content - string_source.add() - assert len(string_source.chunks) > 0 - - # Test querying - query = "Where does John live?" - results = mock_vector_db.query(query) - assert len(results) > 0 - assert "San Francisco" in results[0]["context"] - mock_vector_db.query.assert_called_once() + def test_add_and_query(self, mock_vector_db): + """Test adding content and querying.""" + content = "Users name is John. He is 30 years old and lives in San Francisco." + string_source = StringKnowledgeSource(content=content) + string_source.storage = mock_vector_db + + mock_vector_db.query.return_value = [{"context": content, "score": 0.9}] + + string_source.add() + assert len(string_source.chunks) > 0 + + query = "Where does John live?" + results = mock_vector_db.query(query) + assert len(results) > 0 + assert "San Francisco" in results[0]["context"] + mock_vector_db.query.assert_called_once() + + def test_empty_content(self, mock_vector_db): + """Test that empty content raises ValueError.""" + with pytest.raises(ValueError, match="StringKnowledgeSource only accepts string content"): + StringKnowledgeSource(content="") + + def test_non_string_content(self, mock_vector_db): + """Test that non-string content raises ValidationError.""" + with pytest.raises(ValidationError, match="Input should be a valid string"): + StringKnowledgeSource(content=123) def test_single_short_string(mock_vector_db):