mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-26 16:48:13 +00:00
fix: delegate collection name sanitization to knowledge store
This commit is contained in:
committed by
Lucas Gomide
parent
df25703cc2
commit
6b14ffcffb
@@ -142,20 +142,13 @@ class Agent(BaseAgent):
|
|||||||
self.embedder = crew_embedder
|
self.embedder = crew_embedder
|
||||||
|
|
||||||
if self.knowledge_sources:
|
if self.knowledge_sources:
|
||||||
try:
|
|
||||||
from crewai.utilities import sanitize_collection_name
|
|
||||||
knowledge_agent_name = sanitize_collection_name(self.role)
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.warning(f"Error sanitizing collection name: {e}")
|
|
||||||
knowledge_agent_name = "default_agent"
|
|
||||||
|
|
||||||
if isinstance(self.knowledge_sources, list) and all(
|
if isinstance(self.knowledge_sources, list) and all(
|
||||||
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
|
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
|
||||||
):
|
):
|
||||||
self.knowledge = Knowledge(
|
self.knowledge = Knowledge(
|
||||||
sources=self.knowledge_sources,
|
sources=self.knowledge_sources,
|
||||||
embedder=self.embedder,
|
embedder=self.embedder,
|
||||||
collection_name=knowledge_agent_name,
|
collection_name=self.role,
|
||||||
storage=self.knowledge_storage or None,
|
storage=self.knowledge_storage or None,
|
||||||
)
|
)
|
||||||
except (TypeError, ValueError) as e:
|
except (TypeError, ValueError) as e:
|
||||||
|
|||||||
@@ -98,8 +98,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
else "knowledge"
|
else "knowledge"
|
||||||
)
|
)
|
||||||
if self.app:
|
if self.app:
|
||||||
|
from crewai.utilities.chromadb import sanitize_collection_name
|
||||||
|
|
||||||
self.collection = self.app.get_or_create_collection(
|
self.collection = self.app.get_or_create_collection(
|
||||||
name=collection_name, embedding_function=self.embedder
|
name=sanitize_collection_name(collection_name),
|
||||||
|
embedding_function=self.embedder,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception("Vector Database Client not initialized")
|
raise Exception("Vector Database Client not initialized")
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from .parser import YamlParser
|
|||||||
from .printer import Printer
|
from .printer import Printer
|
||||||
from .prompts import Prompts
|
from .prompts import Prompts
|
||||||
from .rpm_controller import RPMController
|
from .rpm_controller import RPMController
|
||||||
from .string_utils import sanitize_collection_name, is_ipv4_pattern
|
|
||||||
from .exceptions.context_window_exceeding_exception import (
|
from .exceptions.context_window_exceeding_exception import (
|
||||||
LLMContextLengthExceededException,
|
LLMContextLengthExceededException,
|
||||||
)
|
)
|
||||||
@@ -26,6 +25,4 @@ __all__ = [
|
|||||||
"YamlParser",
|
"YamlParser",
|
||||||
"LLMContextLengthExceededException",
|
"LLMContextLengthExceededException",
|
||||||
"EmbeddingConfigurator",
|
"EmbeddingConfigurator",
|
||||||
"sanitize_collection_name",
|
|
||||||
"is_ipv4_pattern",
|
|
||||||
]
|
]
|
||||||
|
|||||||
62
src/crewai/utilities/chromadb.py
Normal file
62
src/crewai/utilities/chromadb.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# ChromaDB collection-name length bounds and the fallback name used when no
# name is supplied.
MIN_COLLECTION_LENGTH = 3
MAX_COLLECTION_LENGTH = 63
DEFAULT_COLLECTION = "default_collection"

# Compiled once at import time so repeated calls do not re-parse the patterns.
INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")


def is_ipv4_pattern(name: str) -> bool:
    """
    Check if a string has the dotted-quad shape of an IPv4 address.

    Only the shape is checked (four dot-separated groups of one to three
    digits); octet values are not range-checked, so "999.999.999.999" is
    also reported as a match.

    Args:
        name: The string to check

    Returns:
        True if the whole string matches the IPv4 pattern, False otherwise
    """
    # fullmatch() rather than match(): "$" also matches just before a
    # trailing newline, so match() would wrongly accept "1.2.3.4\n".
    return IPV4_PATTERN.fullmatch(name) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_collection_name(name: Optional[str]) -> str:
    """
    Turn an arbitrary string into a name ChromaDB accepts for a collection.

    The returned name:
        1. Is between MIN_COLLECTION_LENGTH and MAX_COLLECTION_LENGTH
           characters long
        2. Starts and ends with an alphanumeric character
        3. Contains only ASCII letters, digits, underscores, or hyphens
           (every other character, including spaces and periods, becomes "_")
        4. Is not shaped like an IPv4 address (such input gets an "ip_" prefix)

    Args:
        name: The original collection name to sanitize

    Returns:
        The sanitized name, or DEFAULT_COLLECTION when ``name`` is None or
        empty
    """
    if not name:
        return DEFAULT_COLLECTION

    # Prefix IPv4-looking input before the character substitution runs.
    candidate = f"ip_{name}" if is_ipv4_pattern(name) else name
    cleaned = INVALID_CHARS_PATTERN.sub("_", candidate)

    # A bad leading character is kept and prefixed with "a"; a bad trailing
    # character is overwritten with "z".
    if not cleaned[0].isalnum():
        cleaned = f"a{cleaned}"
    if not cleaned[-1].isalnum():
        cleaned = f"{cleaned[:-1]}z"

    # Right-pad names that are too short with "x".
    cleaned = cleaned.ljust(MIN_COLLECTION_LENGTH, "x")

    if len(cleaned) > MAX_COLLECTION_LENGTH:
        cleaned = cleaned[:MAX_COLLECTION_LENGTH]
        # Truncation may have exposed "_" or "-" as the new last character.
        if not cleaned[-1].isalnum():
            cleaned = f"{cleaned[:-1]}z"

    return cleaned
|
||||||
@@ -80,74 +80,3 @@ def interpolate_only(
|
|||||||
result = result.replace(placeholder, value)
|
result = result.replace(placeholder, value)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
# Constants for ChromaDB collection name requirements
|
|
||||||
MIN_LENGTH = 3
|
|
||||||
MAX_LENGTH = 63
|
|
||||||
DEFAULT_COLLECTION = "default_collection"
|
|
||||||
|
|
||||||
# Compiled regex patterns for better performance
|
|
||||||
INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
|
|
||||||
IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
|
|
||||||
|
|
||||||
|
|
||||||
def is_ipv4_pattern(name: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a string matches an IPv4 address pattern.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The string to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the string matches an IPv4 pattern, False otherwise
|
|
||||||
"""
|
|
||||||
return bool(IPV4_PATTERN.match(name))
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_collection_name(name: Optional[str]) -> str:
|
|
||||||
"""
|
|
||||||
Sanitize a collection name to meet ChromaDB requirements:
|
|
||||||
1. 3-63 characters long
|
|
||||||
2. Starts and ends with alphanumeric character
|
|
||||||
3. Contains only alphanumeric characters, underscores, or hyphens
|
|
||||||
4. No consecutive periods
|
|
||||||
5. Not a valid IPv4 address
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The original collection name to sanitize
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A sanitized collection name that meets ChromaDB requirements
|
|
||||||
"""
|
|
||||||
if not name:
|
|
||||||
return DEFAULT_COLLECTION
|
|
||||||
|
|
||||||
# Handle IPv4 pattern
|
|
||||||
if is_ipv4_pattern(name):
|
|
||||||
name = f"ip_{name}"
|
|
||||||
|
|
||||||
# Replace spaces and invalid characters with underscores
|
|
||||||
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
|
|
||||||
|
|
||||||
# Ensure it starts with alphanumeric
|
|
||||||
if not sanitized[0].isalnum():
|
|
||||||
sanitized = "a" + sanitized
|
|
||||||
|
|
||||||
# Ensure it ends with alphanumeric
|
|
||||||
if not sanitized[-1].isalnum():
|
|
||||||
sanitized = sanitized[:-1] + "z"
|
|
||||||
|
|
||||||
# Ensure length is between MIN_LENGTH-MAX_LENGTH characters
|
|
||||||
if len(sanitized) < MIN_LENGTH:
|
|
||||||
# Add padding with alphanumeric character at the end
|
|
||||||
sanitized = sanitized + "x" * (MIN_LENGTH - len(sanitized))
|
|
||||||
if len(sanitized) > MAX_LENGTH:
|
|
||||||
sanitized = sanitized[:MAX_LENGTH]
|
|
||||||
# Ensure it still ends with alphanumeric after truncation
|
|
||||||
if not sanitized[-1].isalnum():
|
|
||||||
sanitized = sanitized[:-1] + "z"
|
|
||||||
|
|
||||||
return sanitized
|
|
||||||
|
|||||||
@@ -1621,6 +1621,38 @@ def test_agent_with_knowledge_sources():
|
|||||||
assert "red" in result.raw.lower()
|
assert "red" in result.raw.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
|
def test_agent_with_knowledge_sources_extensive_role():
|
||||||
|
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||||
|
string_source = StringKnowledgeSource(content=content)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||||
|
) as MockKnowledge:
|
||||||
|
mock_knowledge_instance = MockKnowledge.return_value
|
||||||
|
mock_knowledge_instance.sources = [string_source]
|
||||||
|
mock_knowledge_instance.query.return_value = [{"content": content}]
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Information Agent with extensive role description that is longer than 80 characters",
|
||||||
|
goal="Provide information based on knowledge sources",
|
||||||
|
backstory="You have access to specific knowledge sources.",
|
||||||
|
llm=LLM(model="gpt-4o-mini"),
|
||||||
|
knowledge_sources=[string_source],
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="What is Brandon's favorite color?",
|
||||||
|
expected_output="Brandon's favorite color.",
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = Crew(agents=[agent], tasks=[task])
|
||||||
|
result = crew.kickoff()
|
||||||
|
|
||||||
|
assert "red" in result.raw.lower()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_agent_with_knowledge_sources_works_with_copy():
|
def test_agent_with_knowledge_sources_works_with_copy():
|
||||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
81
tests/utilities/test_chromadb_utils.py
Normal file
81
tests/utilities/test_chromadb_utils.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import unittest
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.utilities.chromadb import (
|
||||||
|
MAX_COLLECTION_LENGTH,
|
||||||
|
MIN_COLLECTION_LENGTH,
|
||||||
|
is_ipv4_pattern,
|
||||||
|
sanitize_collection_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChromadbUtils(unittest.TestCase):
    """Tests for the ChromaDB collection-name helpers."""

    def _assert_alnum_ends(self, sanitized):
        """Assert the name starts and ends with an alphanumeric character."""
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())

    def _assert_allowed_chars(self, sanitized):
        """Assert the name holds only alphanumerics, underscores, or hyphens."""
        self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))

    def test_sanitize_collection_name_long_name(self):
        """Test sanitizing a very long collection name."""
        long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
        sanitized = sanitize_collection_name(long_name)
        self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
        self._assert_alnum_ends(sanitized)
        self._assert_allowed_chars(sanitized)

    def test_sanitize_collection_name_special_chars(self):
        """Test sanitizing a name with special characters."""
        special_chars = "Agent@123!#$%^&*()"
        sanitized = sanitize_collection_name(special_chars)
        self._assert_alnum_ends(sanitized)
        self._assert_allowed_chars(sanitized)

    def test_sanitize_collection_name_short_name(self):
        """Test sanitizing a very short name."""
        short_name = "A"
        sanitized = sanitize_collection_name(short_name)
        self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
        self._assert_alnum_ends(sanitized)

    def test_sanitize_collection_name_bad_ends(self):
        """Test sanitizing a name with non-alphanumeric start/end."""
        bad_ends = "_Agent_"
        sanitized = sanitize_collection_name(bad_ends)
        self._assert_alnum_ends(sanitized)

    def test_sanitize_collection_name_none(self):
        """Test sanitizing a None value."""
        sanitized = sanitize_collection_name(None)
        self.assertEqual(sanitized, "default_collection")

    def test_sanitize_collection_name_ipv4_pattern(self):
        """Test sanitizing an IPv4 address."""
        ipv4 = "192.168.1.1"
        sanitized = sanitize_collection_name(ipv4)
        self.assertTrue(sanitized.startswith("ip_"))
        self._assert_alnum_ends(sanitized)
        self._assert_allowed_chars(sanitized)

    def test_is_ipv4_pattern(self):
        """Test IPv4 pattern detection."""
        self.assertTrue(is_ipv4_pattern("192.168.1.1"))
        self.assertFalse(is_ipv4_pattern("not.an.ip.address"))

    def test_sanitize_collection_name_properties(self):
        """Test that sanitized collection names always meet ChromaDB requirements."""
        test_cases = [
            "A" * 100,  # Very long name
            "_start_with_underscore",
            "end_with_underscore_",
            "contains@special#characters",
            "192.168.1.1",  # IPv4 address
            "a" * 2,  # Too short
        ]
        for test_case in test_cases:
            # subTest names the failing input instead of stopping at the
            # first failure with no indication of which case broke.
            with self.subTest(name=test_case):
                sanitized = sanitize_collection_name(test_case)
                self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
                self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
                self._assert_alnum_ends(sanitized)
                self._assert_allowed_chars(sanitized)
|
||||||
@@ -1,14 +1,8 @@
|
|||||||
import unittest
|
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from crewai.utilities import is_ipv4_pattern, sanitize_collection_name
|
from crewai.utilities.string_utils import interpolate_only
|
||||||
from crewai.utilities.string_utils import (
|
|
||||||
MAX_LENGTH,
|
|
||||||
MIN_LENGTH,
|
|
||||||
interpolate_only,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestInterpolateOnly:
|
class TestInterpolateOnly:
|
||||||
@@ -191,77 +185,3 @@ class TestInterpolateOnly:
|
|||||||
interpolate_only(template, inputs)
|
interpolate_only(template, inputs)
|
||||||
|
|
||||||
assert "inputs dictionary cannot be empty" in str(excinfo.value).lower()
|
assert "inputs dictionary cannot be empty" in str(excinfo.value).lower()
|
||||||
|
|
||||||
|
|
||||||
class TestStringUtils(unittest.TestCase):
|
|
||||||
def test_sanitize_collection_name_long_name(self):
|
|
||||||
"""Test sanitizing a very long collection name."""
|
|
||||||
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
|
|
||||||
sanitized = sanitize_collection_name(long_name)
|
|
||||||
self.assertLessEqual(len(sanitized), MAX_LENGTH)
|
|
||||||
self.assertTrue(sanitized[0].isalnum())
|
|
||||||
self.assertTrue(sanitized[-1].isalnum())
|
|
||||||
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
|
|
||||||
|
|
||||||
def test_sanitize_collection_name_special_chars(self):
|
|
||||||
"""Test sanitizing a name with special characters."""
|
|
||||||
special_chars = "Agent@123!#$%^&*()"
|
|
||||||
sanitized = sanitize_collection_name(special_chars)
|
|
||||||
self.assertTrue(sanitized[0].isalnum())
|
|
||||||
self.assertTrue(sanitized[-1].isalnum())
|
|
||||||
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
|
|
||||||
|
|
||||||
def test_sanitize_collection_name_short_name(self):
|
|
||||||
"""Test sanitizing a very short name."""
|
|
||||||
short_name = "A"
|
|
||||||
sanitized = sanitize_collection_name(short_name)
|
|
||||||
self.assertGreaterEqual(len(sanitized), MIN_LENGTH)
|
|
||||||
self.assertTrue(sanitized[0].isalnum())
|
|
||||||
self.assertTrue(sanitized[-1].isalnum())
|
|
||||||
|
|
||||||
def test_sanitize_collection_name_bad_ends(self):
|
|
||||||
"""Test sanitizing a name with non-alphanumeric start/end."""
|
|
||||||
bad_ends = "_Agent_"
|
|
||||||
sanitized = sanitize_collection_name(bad_ends)
|
|
||||||
self.assertTrue(sanitized[0].isalnum())
|
|
||||||
self.assertTrue(sanitized[-1].isalnum())
|
|
||||||
|
|
||||||
def test_sanitize_collection_name_none(self):
|
|
||||||
"""Test sanitizing a None value."""
|
|
||||||
sanitized = sanitize_collection_name(None)
|
|
||||||
self.assertEqual(sanitized, "default_collection")
|
|
||||||
|
|
||||||
def test_sanitize_collection_name_ipv4_pattern(self):
|
|
||||||
"""Test sanitizing an IPv4 address."""
|
|
||||||
ipv4 = "192.168.1.1"
|
|
||||||
sanitized = sanitize_collection_name(ipv4)
|
|
||||||
self.assertTrue(sanitized.startswith("ip_"))
|
|
||||||
self.assertTrue(sanitized[0].isalnum())
|
|
||||||
self.assertTrue(sanitized[-1].isalnum())
|
|
||||||
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
|
|
||||||
|
|
||||||
def test_is_ipv4_pattern(self):
|
|
||||||
"""Test IPv4 pattern detection."""
|
|
||||||
self.assertTrue(is_ipv4_pattern("192.168.1.1"))
|
|
||||||
self.assertFalse(is_ipv4_pattern("not.an.ip.address"))
|
|
||||||
|
|
||||||
def test_sanitize_collection_name_properties(self):
|
|
||||||
"""Test that sanitized collection names always meet ChromaDB requirements."""
|
|
||||||
test_cases = [
|
|
||||||
"A" * 100, # Very long name
|
|
||||||
"_start_with_underscore",
|
|
||||||
"end_with_underscore_",
|
|
||||||
"contains@special#characters",
|
|
||||||
"192.168.1.1", # IPv4 address
|
|
||||||
"a" * 2, # Too short
|
|
||||||
]
|
|
||||||
for test_case in test_cases:
|
|
||||||
sanitized = sanitize_collection_name(test_case)
|
|
||||||
self.assertGreaterEqual(len(sanitized), MIN_LENGTH)
|
|
||||||
self.assertLessEqual(len(sanitized), MAX_LENGTH)
|
|
||||||
self.assertTrue(sanitized[0].isalnum())
|
|
||||||
self.assertTrue(sanitized[-1].isalnum())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user