fix: delegate collection name sanitization to knowledge store

This commit is contained in:
lucasgomide
2025-03-25 15:06:00 -03:00
committed by Lucas Gomide
parent df25703cc2
commit 6b14ffcffb
9 changed files with 563 additions and 164 deletions

View File

@@ -142,20 +142,13 @@ class Agent(BaseAgent):
self.embedder = crew_embedder
if self.knowledge_sources:
try:
from crewai.utilities import sanitize_collection_name
knowledge_agent_name = sanitize_collection_name(self.role)
except Exception as e:
self._logger.warning(f"Error sanitizing collection name: {e}")
knowledge_agent_name = "default_agent"
if isinstance(self.knowledge_sources, list) and all(
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
):
self.knowledge = Knowledge(
sources=self.knowledge_sources,
embedder=self.embedder,
collection_name=knowledge_agent_name,
collection_name=self.role,
storage=self.knowledge_storage or None,
)
except (TypeError, ValueError) as e:

View File

@@ -98,8 +98,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
else "knowledge"
)
if self.app:
from crewai.utilities.chromadb import sanitize_collection_name
self.collection = self.app.get_or_create_collection(
name=collection_name, embedding_function=self.embedder
name=sanitize_collection_name(collection_name),
embedding_function=self.embedder,
)
else:
raise Exception("Vector Database Client not initialized")

View File

@@ -7,7 +7,6 @@ from .parser import YamlParser
from .printer import Printer
from .prompts import Prompts
from .rpm_controller import RPMController
from .string_utils import sanitize_collection_name, is_ipv4_pattern
from .exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
@@ -26,6 +25,4 @@ __all__ = [
"YamlParser",
"LLMContextLengthExceededException",
"EmbeddingConfigurator",
"sanitize_collection_name",
"is_ipv4_pattern",
]

View File

@@ -0,0 +1,62 @@
import re
from typing import Optional
MIN_COLLECTION_LENGTH = 3
MAX_COLLECTION_LENGTH = 63
DEFAULT_COLLECTION = "default_collection"
# Compiled regex patterns for better performance
INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
def is_ipv4_pattern(name: str) -> bool:
"""
Check if a string matches an IPv4 address pattern.
Args:
name: The string to check
Returns:
True if the string matches an IPv4 pattern, False otherwise
"""
return bool(IPV4_PATTERN.match(name))
def sanitize_collection_name(name: Optional[str]) -> str:
"""
Sanitize a collection name to meet ChromaDB requirements:
1. 3-63 characters long
2. Starts and ends with alphanumeric character
3. Contains only alphanumeric characters, underscores, or hyphens
4. No consecutive periods
5. Not a valid IPv4 address
Args:
name: The original collection name to sanitize
Returns:
A sanitized collection name that meets ChromaDB requirements
"""
if not name:
return DEFAULT_COLLECTION
if is_ipv4_pattern(name):
name = f"ip_{name}"
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
if not sanitized[0].isalnum():
sanitized = "a" + sanitized
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
if len(sanitized) < MIN_COLLECTION_LENGTH:
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
if len(sanitized) > MAX_COLLECTION_LENGTH:
sanitized = sanitized[:MAX_COLLECTION_LENGTH]
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
return sanitized

View File

@@ -80,74 +80,3 @@ def interpolate_only(
result = result.replace(placeholder, value)
return result
from typing import Optional
# Constants for ChromaDB collection name requirements
MIN_LENGTH = 3
MAX_LENGTH = 63
DEFAULT_COLLECTION = "default_collection"
# Compiled regex patterns for better performance
INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
def is_ipv4_pattern(name: str) -> bool:
"""
Check if a string matches an IPv4 address pattern.
Args:
name: The string to check
Returns:
True if the string matches an IPv4 pattern, False otherwise
"""
return bool(IPV4_PATTERN.match(name))
def sanitize_collection_name(name: Optional[str]) -> str:
"""
Sanitize a collection name to meet ChromaDB requirements:
1. 3-63 characters long
2. Starts and ends with alphanumeric character
3. Contains only alphanumeric characters, underscores, or hyphens
4. No consecutive periods
5. Not a valid IPv4 address
Args:
name: The original collection name to sanitize
Returns:
A sanitized collection name that meets ChromaDB requirements
"""
if not name:
return DEFAULT_COLLECTION
# Handle IPv4 pattern
if is_ipv4_pattern(name):
name = f"ip_{name}"
# Replace spaces and invalid characters with underscores
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
# Ensure it starts with alphanumeric
if not sanitized[0].isalnum():
sanitized = "a" + sanitized
# Ensure it ends with alphanumeric
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
# Ensure length is between MIN_LENGTH-MAX_LENGTH characters
if len(sanitized) < MIN_LENGTH:
# Add padding with alphanumeric character at the end
sanitized = sanitized + "x" * (MIN_LENGTH - len(sanitized))
if len(sanitized) > MAX_LENGTH:
sanitized = sanitized[:MAX_LENGTH]
# Ensure it still ends with alphanumeric after truncation
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
return sanitized

View File

@@ -1621,6 +1621,38 @@ def test_agent_with_knowledge_sources():
assert "red" in result.raw.lower()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_knowledge_sources_extensive_role():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
agent = Agent(
role="Information Agent with extensive role description that is longer than 80 characters",
goal="Provide information based on knowledge sources",
backstory="You have access to specific knowledge sources.",
llm=LLM(model="gpt-4o-mini"),
knowledge_sources=[string_source],
)
task = Task(
description="What is Brandon's favorite color?",
expected_output="Brandon's favorite color.",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert "red" in result.raw.lower()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_knowledge_sources_works_with_copy():
content = "Brandon's favorite color is red and he likes Mexican food."

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,81 @@
import unittest
from typing import Any, Dict, List, Union
import pytest
from crewai.utilities.chromadb import (
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
is_ipv4_pattern,
sanitize_collection_name,
)
class TestChromadbUtils(unittest.TestCase):
def test_sanitize_collection_name_long_name(self):
"""Test sanitizing a very long collection name."""
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
sanitized = sanitize_collection_name(long_name)
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_sanitize_collection_name_special_chars(self):
"""Test sanitizing a name with special characters."""
special_chars = "Agent@123!#$%^&*()"
sanitized = sanitize_collection_name(special_chars)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_sanitize_collection_name_short_name(self):
"""Test sanitizing a very short name."""
short_name = "A"
sanitized = sanitize_collection_name(short_name)
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_sanitize_collection_name_bad_ends(self):
"""Test sanitizing a name with non-alphanumeric start/end."""
bad_ends = "_Agent_"
sanitized = sanitize_collection_name(bad_ends)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_sanitize_collection_name_none(self):
"""Test sanitizing a None value."""
sanitized = sanitize_collection_name(None)
self.assertEqual(sanitized, "default_collection")
def test_sanitize_collection_name_ipv4_pattern(self):
"""Test sanitizing an IPv4 address."""
ipv4 = "192.168.1.1"
sanitized = sanitize_collection_name(ipv4)
self.assertTrue(sanitized.startswith("ip_"))
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_is_ipv4_pattern(self):
"""Test IPv4 pattern detection."""
self.assertTrue(is_ipv4_pattern("192.168.1.1"))
self.assertFalse(is_ipv4_pattern("not.an.ip.address"))
def test_sanitize_collection_name_properties(self):
"""Test that sanitized collection names always meet ChromaDB requirements."""
test_cases = [
"A" * 100, # Very long name
"_start_with_underscore",
"end_with_underscore_",
"contains@special#characters",
"192.168.1.1", # IPv4 address
"a" * 2, # Too short
]
for test_case in test_cases:
sanitized = sanitize_collection_name(test_case)
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())

View File

@@ -1,14 +1,8 @@
import unittest
from typing import Any, Dict, List, Union
import pytest
from crewai.utilities import is_ipv4_pattern, sanitize_collection_name
from crewai.utilities.string_utils import (
MAX_LENGTH,
MIN_LENGTH,
interpolate_only,
)
from crewai.utilities.string_utils import interpolate_only
class TestInterpolateOnly:
@@ -191,77 +185,3 @@ class TestInterpolateOnly:
interpolate_only(template, inputs)
assert "inputs dictionary cannot be empty" in str(excinfo.value).lower()
class TestStringUtils(unittest.TestCase):
def test_sanitize_collection_name_long_name(self):
"""Test sanitizing a very long collection name."""
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
sanitized = sanitize_collection_name(long_name)
self.assertLessEqual(len(sanitized), MAX_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_sanitize_collection_name_special_chars(self):
"""Test sanitizing a name with special characters."""
special_chars = "Agent@123!#$%^&*()"
sanitized = sanitize_collection_name(special_chars)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_sanitize_collection_name_short_name(self):
"""Test sanitizing a very short name."""
short_name = "A"
sanitized = sanitize_collection_name(short_name)
self.assertGreaterEqual(len(sanitized), MIN_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_sanitize_collection_name_bad_ends(self):
"""Test sanitizing a name with non-alphanumeric start/end."""
bad_ends = "_Agent_"
sanitized = sanitize_collection_name(bad_ends)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_sanitize_collection_name_none(self):
"""Test sanitizing a None value."""
sanitized = sanitize_collection_name(None)
self.assertEqual(sanitized, "default_collection")
def test_sanitize_collection_name_ipv4_pattern(self):
"""Test sanitizing an IPv4 address."""
ipv4 = "192.168.1.1"
sanitized = sanitize_collection_name(ipv4)
self.assertTrue(sanitized.startswith("ip_"))
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_is_ipv4_pattern(self):
"""Test IPv4 pattern detection."""
self.assertTrue(is_ipv4_pattern("192.168.1.1"))
self.assertFalse(is_ipv4_pattern("not.an.ip.address"))
def test_sanitize_collection_name_properties(self):
"""Test that sanitized collection names always meet ChromaDB requirements."""
test_cases = [
"A" * 100, # Very long name
"_start_with_underscore",
"end_with_underscore_",
"contains@special#characters",
"192.168.1.1", # IPv4 address
"a" * 2, # Too short
]
for test_case in test_cases:
sanitized = sanitize_collection_name(test_case)
self.assertGreaterEqual(len(sanitized), MIN_LENGTH)
self.assertLessEqual(len(sanitized), MAX_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
if __name__ == "__main__":
unittest.main()