diff --git a/src/crewai/agent.py b/src/crewai/agent.py
index b92a83d14..d8b6860e3 100644
--- a/src/crewai/agent.py
+++ b/src/crewai/agent.py
@@ -142,8 +142,13 @@ class Agent(BaseAgent):
                 self.embedder = crew_embedder
 
             if self.knowledge_sources:
-                from crewai.utilities import sanitize_collection_name
-                knowledge_agent_name = sanitize_collection_name(self.role)
+                try:
+                    from crewai.utilities import sanitize_collection_name
+                    knowledge_agent_name = sanitize_collection_name(self.role)
+                except Exception as e:
+                    self._logger.warning(f"Error sanitizing collection name: {e}")
+                    knowledge_agent_name = "default_agent"
+
                 if isinstance(self.knowledge_sources, list) and all(
                     isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
                 ):
diff --git a/src/crewai/utilities/__init__.py b/src/crewai/utilities/__init__.py
index f2badd2d4..946c4390a 100644
--- a/src/crewai/utilities/__init__.py
+++ b/src/crewai/utilities/__init__.py
@@ -7,7 +7,7 @@ from .parser import YamlParser
 from .printer import Printer
 from .prompts import Prompts
 from .rpm_controller import RPMController
-from .string_utils import sanitize_collection_name
+from .string_utils import sanitize_collection_name, is_ipv4_pattern
 from .exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededException,
 )
@@ -27,4 +27,5 @@ __all__ = [
     "LLMContextLengthExceededException",
     "EmbeddingConfigurator",
     "sanitize_collection_name",
+    "is_ipv4_pattern",
 ]
diff --git a/src/crewai/utilities/string_utils.py b/src/crewai/utilities/string_utils.py
index 6da07b20d..05a637383 100644
--- a/src/crewai/utilities/string_utils.py
+++ b/src/crewai/utilities/string_utils.py
@@ -84,6 +84,28 @@ def interpolate_only(
 
 from typing import Optional
 
+# Constants for ChromaDB collection name requirements
+MIN_LENGTH = 3
+MAX_LENGTH = 63
+DEFAULT_COLLECTION = "default_collection"
+
+# Compiled regex patterns for better performance
+INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
+IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
+
+
+def is_ipv4_pattern(name: str) -> bool:
+    """
+    Check if a string matches an IPv4 address pattern.
+
+    Args:
+        name: The string to check
+
+    Returns:
+        True if the string matches an IPv4 pattern, False otherwise
+    """
+    return bool(IPV4_PATTERN.match(name))
+
 
 def sanitize_collection_name(name: Optional[str]) -> str:
     """
@@ -101,10 +123,14 @@ def sanitize_collection_name(name: Optional[str]) -> str:
         A sanitized collection name that meets ChromaDB requirements
     """
     if not name:
-        return "default_collection"
+        return DEFAULT_COLLECTION
+
+    # Handle IPv4 pattern
+    if is_ipv4_pattern(name):
+        name = f"ip_{name}"
 
     # Replace spaces and invalid characters with underscores
-    sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
+    sanitized = INVALID_CHARS_PATTERN.sub("_", name)
 
     # Ensure it starts with alphanumeric
     if not sanitized[0].isalnum():
@@ -114,12 +140,12 @@ def sanitize_collection_name(name: Optional[str]) -> str:
     if not sanitized[-1].isalnum():
         sanitized = sanitized[:-1] + "z"
 
-    # Ensure length is between 3-63 characters
-    if len(sanitized) < 3:
+    # Ensure length is between MIN_LENGTH-MAX_LENGTH characters
+    if len(sanitized) < MIN_LENGTH:
         # Add padding with alphanumeric character at the end
-        sanitized = sanitized + "x" * (3 - len(sanitized))
-    if len(sanitized) > 63:
-        sanitized = sanitized[:63]
+        sanitized = sanitized + "x" * (MIN_LENGTH - len(sanitized))
+    if len(sanitized) > MAX_LENGTH:
+        sanitized = sanitized[:MAX_LENGTH]
         # Ensure it still ends with alphanumeric after truncation
         if not sanitized[-1].isalnum():
             sanitized = sanitized[:-1] + "z"
diff --git a/tests/utilities/test_string_utils.py b/tests/utilities/test_string_utils.py
index 04a0dcb56..2e2cf2e0c 100644
--- a/tests/utilities/test_string_utils.py
+++ b/tests/utilities/test_string_utils.py
@@ -3,7 +3,12 @@ from typing import Any, Dict, List, Union
 
 import pytest
 
-from crewai.utilities.string_utils import interpolate_only, sanitize_collection_name
+from crewai.utilities import is_ipv4_pattern, sanitize_collection_name
+from crewai.utilities.string_utils import (
+    MAX_LENGTH,
+    MIN_LENGTH,
+    interpolate_only,
+)
 
 
 class TestInterpolateOnly:
@@ -193,7 +198,7 @@ class TestStringUtils(unittest.TestCase):
         """Test sanitizing a very long collection name."""
         long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
         sanitized = sanitize_collection_name(long_name)
-        self.assertLessEqual(len(sanitized), 63)
+        self.assertLessEqual(len(sanitized), MAX_LENGTH)
         self.assertTrue(sanitized[0].isalnum())
         self.assertTrue(sanitized[-1].isalnum())
         self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
@@ -210,7 +215,7 @@ class TestStringUtils(unittest.TestCase):
         """Test sanitizing a very short name."""
         short_name = "A"
         sanitized = sanitize_collection_name(short_name)
-        self.assertGreaterEqual(len(sanitized), 3)
+        self.assertGreaterEqual(len(sanitized), MIN_LENGTH)
         self.assertTrue(sanitized[0].isalnum())
         self.assertTrue(sanitized[-1].isalnum())
 
@@ -226,6 +231,37 @@ class TestStringUtils(unittest.TestCase):
         sanitized = sanitize_collection_name(None)
         self.assertEqual(sanitized, "default_collection")
 
+    def test_sanitize_collection_name_ipv4_pattern(self):
+        """Test sanitizing an IPv4 address."""
+        ipv4 = "192.168.1.1"
+        sanitized = sanitize_collection_name(ipv4)
+        self.assertTrue(sanitized.startswith("ip_"))
+        self.assertTrue(sanitized[0].isalnum())
+        self.assertTrue(sanitized[-1].isalnum())
+        self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
+
+    def test_is_ipv4_pattern(self):
+        """Test IPv4 pattern detection."""
+        self.assertTrue(is_ipv4_pattern("192.168.1.1"))
+        self.assertFalse(is_ipv4_pattern("not.an.ip.address"))
+
+    def test_sanitize_collection_name_properties(self):
+        """Test that sanitized collection names always meet ChromaDB requirements."""
+        test_cases = [
+            "A" * 100,  # Very long name
+            "_start_with_underscore",
+            "end_with_underscore_",
+            "contains@special#characters",
+            "192.168.1.1",  # IPv4 address
+            "a" * 2,  # Too short
+        ]
+        for test_case in test_cases:
+            sanitized = sanitize_collection_name(test_case)
+            self.assertGreaterEqual(len(sanitized), MIN_LENGTH)
+            self.assertLessEqual(len(sanitized), MAX_LENGTH)
+            self.assertTrue(sanitized[0].isalnum())
+            self.assertTrue(sanitized[-1].isalnum())
+
 
 if __name__ == "__main__":
     unittest.main()