mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
Address PR review: Add constants, IPv4 validation, error handling, and expanded tests
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -142,8 +142,13 @@ class Agent(BaseAgent):
|
||||
self.embedder = crew_embedder
|
||||
|
||||
if self.knowledge_sources:
|
||||
try:
|
||||
from crewai.utilities import sanitize_collection_name
|
||||
knowledge_agent_name = sanitize_collection_name(self.role)
|
||||
except Exception as e:
|
||||
self._logger.warning(f"Error sanitizing collection name: {e}")
|
||||
knowledge_agent_name = "default_agent"
|
||||
|
||||
if isinstance(self.knowledge_sources, list) and all(
|
||||
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
|
||||
):
|
||||
|
||||
@@ -7,7 +7,7 @@ from .parser import YamlParser
|
||||
from .printer import Printer
|
||||
from .prompts import Prompts
|
||||
from .rpm_controller import RPMController
|
||||
from .string_utils import sanitize_collection_name
|
||||
from .string_utils import sanitize_collection_name, is_ipv4_pattern
|
||||
from .exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
@@ -27,4 +27,5 @@ __all__ = [
|
||||
"LLMContextLengthExceededException",
|
||||
"EmbeddingConfigurator",
|
||||
"sanitize_collection_name",
|
||||
"is_ipv4_pattern",
|
||||
]
|
||||
|
||||
@@ -84,6 +84,28 @@ def interpolate_only(
|
||||
|
||||
from typing import Optional
|
||||
|
||||
# Constants for ChromaDB collection name requirements
|
||||
MIN_LENGTH = 3
|
||||
MAX_LENGTH = 63
|
||||
DEFAULT_COLLECTION = "default_collection"
|
||||
|
||||
# Compiled regex patterns for better performance
|
||||
INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
|
||||
IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
|
||||
|
||||
|
||||
def is_ipv4_pattern(name: str) -> bool:
|
||||
"""
|
||||
Check if a string matches an IPv4 address pattern.
|
||||
|
||||
Args:
|
||||
name: The string to check
|
||||
|
||||
Returns:
|
||||
True if the string matches an IPv4 pattern, False otherwise
|
||||
"""
|
||||
return bool(IPV4_PATTERN.match(name))
|
||||
|
||||
|
||||
def sanitize_collection_name(name: Optional[str]) -> str:
|
||||
"""
|
||||
@@ -101,10 +123,14 @@ def sanitize_collection_name(name: Optional[str]) -> str:
|
||||
A sanitized collection name that meets ChromaDB requirements
|
||||
"""
|
||||
if not name:
|
||||
return "default_collection"
|
||||
return DEFAULT_COLLECTION
|
||||
|
||||
# Handle IPv4 pattern
|
||||
if is_ipv4_pattern(name):
|
||||
name = f"ip_{name}"
|
||||
|
||||
# Replace spaces and invalid characters with underscores
|
||||
sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
|
||||
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
|
||||
|
||||
# Ensure it starts with alphanumeric
|
||||
if not sanitized[0].isalnum():
|
||||
@@ -114,12 +140,12 @@ def sanitize_collection_name(name: Optional[str]) -> str:
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
# Ensure length is between 3-63 characters
|
||||
if len(sanitized) < 3:
|
||||
# Ensure length is between MIN_LENGTH-MAX_LENGTH characters
|
||||
if len(sanitized) < MIN_LENGTH:
|
||||
# Add padding with alphanumeric character at the end
|
||||
sanitized = sanitized + "x" * (3 - len(sanitized))
|
||||
if len(sanitized) > 63:
|
||||
sanitized = sanitized[:63]
|
||||
sanitized = sanitized + "x" * (MIN_LENGTH - len(sanitized))
|
||||
if len(sanitized) > MAX_LENGTH:
|
||||
sanitized = sanitized[:MAX_LENGTH]
|
||||
# Ensure it still ends with alphanumeric after truncation
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
@@ -3,7 +3,12 @@ from typing import Any, Dict, List, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.string_utils import interpolate_only, sanitize_collection_name
|
||||
from crewai.utilities import is_ipv4_pattern, sanitize_collection_name
|
||||
from crewai.utilities.string_utils import (
|
||||
MAX_LENGTH,
|
||||
MIN_LENGTH,
|
||||
interpolate_only,
|
||||
)
|
||||
|
||||
|
||||
class TestInterpolateOnly:
|
||||
@@ -193,7 +198,7 @@ class TestStringUtils(unittest.TestCase):
|
||||
"""Test sanitizing a very long collection name."""
|
||||
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
|
||||
sanitized = sanitize_collection_name(long_name)
|
||||
self.assertLessEqual(len(sanitized), 63)
|
||||
self.assertLessEqual(len(sanitized), MAX_LENGTH)
|
||||
self.assertTrue(sanitized[0].isalnum())
|
||||
self.assertTrue(sanitized[-1].isalnum())
|
||||
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
|
||||
@@ -210,7 +215,7 @@ class TestStringUtils(unittest.TestCase):
|
||||
"""Test sanitizing a very short name."""
|
||||
short_name = "A"
|
||||
sanitized = sanitize_collection_name(short_name)
|
||||
self.assertGreaterEqual(len(sanitized), 3)
|
||||
self.assertGreaterEqual(len(sanitized), MIN_LENGTH)
|
||||
self.assertTrue(sanitized[0].isalnum())
|
||||
self.assertTrue(sanitized[-1].isalnum())
|
||||
|
||||
@@ -226,6 +231,37 @@ class TestStringUtils(unittest.TestCase):
|
||||
sanitized = sanitize_collection_name(None)
|
||||
self.assertEqual(sanitized, "default_collection")
|
||||
|
||||
def test_sanitize_collection_name_ipv4_pattern(self):
|
||||
"""Test sanitizing an IPv4 address."""
|
||||
ipv4 = "192.168.1.1"
|
||||
sanitized = sanitize_collection_name(ipv4)
|
||||
self.assertTrue(sanitized.startswith("ip_"))
|
||||
self.assertTrue(sanitized[0].isalnum())
|
||||
self.assertTrue(sanitized[-1].isalnum())
|
||||
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
|
||||
|
||||
def test_is_ipv4_pattern(self):
|
||||
"""Test IPv4 pattern detection."""
|
||||
self.assertTrue(is_ipv4_pattern("192.168.1.1"))
|
||||
self.assertFalse(is_ipv4_pattern("not.an.ip.address"))
|
||||
|
||||
def test_sanitize_collection_name_properties(self):
|
||||
"""Test that sanitized collection names always meet ChromaDB requirements."""
|
||||
test_cases = [
|
||||
"A" * 100, # Very long name
|
||||
"_start_with_underscore",
|
||||
"end_with_underscore_",
|
||||
"contains@special#characters",
|
||||
"192.168.1.1", # IPv4 address
|
||||
"a" * 2, # Too short
|
||||
]
|
||||
for test_case in test_cases:
|
||||
sanitized = sanitize_collection_name(test_case)
|
||||
self.assertGreaterEqual(len(sanitized), MIN_LENGTH)
|
||||
self.assertLessEqual(len(sanitized), MAX_LENGTH)
|
||||
self.assertTrue(sanitized[0].isalnum())
|
||||
self.assertTrue(sanitized[-1].isalnum())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user