From 8aec7b3364f1e1085ada6c29ae50682d83d81141 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 03:22:00 +0000 Subject: [PATCH] Fix #2534: Handle non-ASCII characters in agent roles for knowledge sources Co-Authored-By: Joe Moura --- src/crewai/agent.py | 4 +- src/crewai/utilities/__init__.py | 2 + src/crewai/utilities/string_utils.py | 37 +++++++++++++++++ tests/test_agent_non_ascii.py | 60 ++++++++++++++++++++++++++++ 4 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 tests/test_agent_non_ascii.py diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 14c6d7bad..b9335e3fd 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -144,13 +144,15 @@ class Agent(BaseAgent): self.embedder = crew_embedder if self.knowledge_sources: + from crewai.utilities import sanitize_collection_name + if isinstance(self.knowledge_sources, list) and all( isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources ): self.knowledge = Knowledge( sources=self.knowledge_sources, embedder=self.embedder, - collection_name=self.role, + collection_name=sanitize_collection_name(self.role), storage=self.knowledge_storage or None, ) except (TypeError, ValueError) as e: diff --git a/src/crewai/utilities/__init__.py b/src/crewai/utilities/__init__.py index dd6d9fa44..a626099e7 100644 --- a/src/crewai/utilities/__init__.py +++ b/src/crewai/utilities/__init__.py @@ -11,6 +11,7 @@ from .exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededException, ) from .embedding_configurator import EmbeddingConfigurator +from .string_utils import sanitize_collection_name __all__ = [ "Converter", @@ -25,4 +26,5 @@ __all__ = [ "YamlParser", "LLMContextLengthExceededException", "EmbeddingConfigurator", + "sanitize_collection_name", ] diff --git a/src/crewai/utilities/string_utils.py b/src/crewai/utilities/string_utils.py index 9a1857781..0d6bac46c 100644 --- a/src/crewai/utilities/string_utils.py +++ b/src/crewai/utilities/string_utils.py @@ -1,5 +1,6 @@ import re from typing import Any, Dict, List, Optional, Union +from unidecode import unidecode def interpolate_only( @@ -80,3 +81,39 @@ def interpolate_only( result = result.replace(placeholder, value) return result + + +def sanitize_collection_name(name: str) -> str: + """ + Sanitizes a string to be used as a ChromaDB collection name. + + ChromaDB collection names must: + 1. Contain 3-63 characters + 2. Start and end with an alphanumeric character + 3. Otherwise contain only alphanumeric characters, underscores or hyphens (-) + 4. Contain no two consecutive periods (..) + 5. Not be a valid IPv4 address + + Args: + name: The string to sanitize + + Returns: + A sanitized string that can be used as a ChromaDB collection name + """ + name = unidecode(name) + + name = re.sub(r'[^\w\-]', '_', name) + + name = re.sub(r'_+', '_', name) + + name = re.sub(r'^[^a-zA-Z0-9]+', '', name) + name = re.sub(r'[^a-zA-Z0-9]+$', '', name) + + if len(name) < 3: + name = name + 'x' * (3 - len(name)) + + if len(name) > 63: + name = name[:63] + name = re.sub(r'[^a-zA-Z0-9]+$', '', name) + + return name diff --git a/tests/test_agent_non_ascii.py b/tests/test_agent_non_ascii.py new file mode 100644 index 000000000..2ddc5140f --- /dev/null +++ b/tests/test_agent_non_ascii.py @@ -0,0 +1,60 @@ +import pytest +from crewai.utilities import sanitize_collection_name + + +def test_sanitize_collection_name_with_non_ascii_chars(): + """Test that sanitize_collection_name properly handles non-ASCII characters.""" + chinese_role = "一位有 20 年经验的 GraphQL 查询专家" + sanitized_name = sanitize_collection_name(chinese_role) + + assert len(sanitized_name) >= 3 + assert len(sanitized_name) <= 63 + assert sanitized_name[0].isalnum() + assert sanitized_name[-1].isalnum() + assert all(c.isalnum() or c == '_' or c == '-' for c in sanitized_name) + assert '__' not in sanitized_name # No consecutive underscores + + special_chars_role = "Café Owner & Barista (España) 🇪🇸" + sanitized_name = sanitize_collection_name(special_chars_role) + + assert len(sanitized_name) >= 3 + assert len(sanitized_name) <= 63 + assert sanitized_name[0].isalnum() + assert sanitized_name[-1].isalnum() + assert all(c.isalnum() or c == '_' or c == '-' for c in sanitized_name) + assert '__' not in sanitized_name # No consecutive underscores + + +def test_sanitize_collection_name_edge_cases(): + """Test edge cases for sanitize_collection_name function.""" + empty_role = "" + sanitized_name = sanitize_collection_name(empty_role) + assert len(sanitized_name) >= 3 # Should be padded to minimum length + + special_only = "!@#$%^&*()" + sanitized_name = sanitize_collection_name(special_only) + assert len(sanitized_name) >= 3 + assert sanitized_name[0].isalnum() + assert sanitized_name[-1].isalnum() + + long_role = "a" * 100 + sanitized_name = sanitize_collection_name(long_role) + assert len(sanitized_name) <= 63 + + consecutive_spaces = "Hello World" + sanitized_name = sanitize_collection_name(consecutive_spaces) + assert "__" not in sanitized_name + + +def test_sanitize_collection_name_reproduces_issue_2534(): + """Test that reproduces the specific issue from #2534.""" + problematic_role = "一位有 20 年经验的 GraphQL 查询专家" + + sanitized_name = sanitize_collection_name(problematic_role) + + assert len(sanitized_name) >= 3 + assert len(sanitized_name) <= 63 + assert sanitized_name[0].isalnum() + assert sanitized_name[-1].isalnum() + assert all(c.isalnum() or c == '_' or c == '-' for c in sanitized_name) + assert '__' not in sanitized_name # No consecutive underscores