mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
Compare commits
2 Commits
devin/1762
...
devin/1746
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70379689cf | ||
|
|
e891563135 |
@@ -135,13 +135,42 @@ class EmbeddingConfigurator:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_huggingface(config, model_name):
|
||||
def _normalize_api_url(api_url: str) -> str:
|
||||
"""
|
||||
Normalize API URL by ensuring it has a protocol.
|
||||
|
||||
Args:
|
||||
api_url: The API URL to normalize
|
||||
|
||||
Returns:
|
||||
Normalized URL with protocol (defaults to http:// if missing)
|
||||
"""
|
||||
if not (api_url.startswith("http://") or api_url.startswith("https://")):
|
||||
return f"http://{api_url}"
|
||||
return api_url
|
||||
|
||||
@staticmethod
|
||||
def _configure_huggingface(config: dict, model_name: str):
|
||||
"""
|
||||
Configure Huggingface embedding function with the provided config.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary for the Huggingface embedder
|
||||
model_name: Name of the model to use
|
||||
|
||||
Returns:
|
||||
Configured HuggingFaceEmbeddingServer instance
|
||||
"""
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
|
||||
api_url = config.get("api_url")
|
||||
if api_url:
|
||||
api_url = EmbeddingConfigurator._normalize_api_url(api_url)
|
||||
|
||||
return HuggingFaceEmbeddingServer(
|
||||
url=config.get("api_url"),
|
||||
url=api_url,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
"""Test Flow creation and execution basic functionality."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
|
||||
@@ -324,91 +322,3 @@ def test_router_with_multiple_conditions():
|
||||
|
||||
# final_step should run after router_and
|
||||
assert execution_order.index("log_final_step") > execution_order.index("router_and")
|
||||
|
||||
|
||||
def test_flow_with_rlock_in_state():
|
||||
"""Test that Flow can handle unpickleable objects like RLock in state.
|
||||
|
||||
Regression test for issue #3828: Flow should not crash when state contains
|
||||
objects that cannot be deep copied (like threading.RLock).
|
||||
|
||||
In version 1.3.0, Flow._copy_state() used copy.deepcopy() which would fail
|
||||
with "TypeError: cannot pickle '_thread.RLock' object" when state contained
|
||||
threading locks (e.g., from memory components or LLM instances).
|
||||
|
||||
The current implementation no longer deep copies state, so this test verifies
|
||||
that flows with unpickleable objects in state work correctly.
|
||||
"""
|
||||
execution_order = []
|
||||
|
||||
class StateWithRLock(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
counter: int = 0
|
||||
lock: threading.RLock = None
|
||||
|
||||
class FlowWithRLock(Flow[StateWithRLock]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
execution_order.append("step_1")
|
||||
self.state.counter += 1
|
||||
|
||||
@listen(step_1)
|
||||
def step_2(self):
|
||||
execution_order.append("step_2")
|
||||
self.state.counter += 1
|
||||
|
||||
flow = FlowWithRLock()
|
||||
flow._state.lock = threading.RLock()
|
||||
|
||||
flow.kickoff()
|
||||
|
||||
assert execution_order == ["step_1", "step_2"]
|
||||
assert flow.state.counter == 2
|
||||
|
||||
|
||||
def test_flow_with_nested_unpickleable_objects():
|
||||
"""Test that Flow can handle unpickleable objects nested in containers.
|
||||
|
||||
Regression test for issue #3828: Verifies that unpickleable objects
|
||||
nested inside dicts/lists in state don't cause crashes.
|
||||
|
||||
This simulates real-world scenarios where memory components or other
|
||||
resources with locks might be stored in nested data structures.
|
||||
"""
|
||||
execution_order = []
|
||||
|
||||
class NestedState(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
data: dict = {}
|
||||
items: list = []
|
||||
|
||||
class FlowWithNestedUnpickleable(Flow[NestedState]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
execution_order.append("step_1")
|
||||
self.state.data["lock"] = threading.RLock()
|
||||
self.state.data["value"] = 42
|
||||
|
||||
@listen(step_1)
|
||||
def step_2(self):
|
||||
execution_order.append("step_2")
|
||||
self.state.items.append(threading.Lock())
|
||||
self.state.items.append("normal_value")
|
||||
|
||||
@listen(step_2)
|
||||
def step_3(self):
|
||||
execution_order.append("step_3")
|
||||
assert self.state.data["value"] == 42
|
||||
assert len(self.state.items) == 2
|
||||
|
||||
flow = FlowWithNestedUnpickleable()
|
||||
|
||||
flow.kickoff()
|
||||
|
||||
assert execution_order == ["step_1", "step_2", "step_3"]
|
||||
assert flow.state.data["value"] == 42
|
||||
assert len(flow.state.items) == 2
|
||||
|
||||
@@ -584,3 +584,84 @@ def test_docling_source_with_local_file():
|
||||
docling_source = CrewDoclingSource(file_paths=[pdf_path])
|
||||
assert docling_source.file_paths == [pdf_path]
|
||||
assert docling_source.content is not None
|
||||
|
||||
|
||||
def test_huggingface_url_validation():
|
||||
"""Test that Huggingface embedder properly handles URLs without protocol."""
|
||||
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
|
||||
|
||||
config_missing_protocol = {
|
||||
"api_url": "localhost:8080/embed"
|
||||
}
|
||||
embedding_function = EmbeddingConfigurator()._configure_huggingface(
|
||||
config_missing_protocol, "test-model"
|
||||
)
|
||||
# Verify that the URL now has a protocol
|
||||
assert embedding_function._api_url.startswith("http://")
|
||||
|
||||
config_with_protocol = {
|
||||
"api_url": "https://localhost:8080/embed"
|
||||
}
|
||||
embedding_function = EmbeddingConfigurator()._configure_huggingface(
|
||||
config_with_protocol, "test-model"
|
||||
)
|
||||
# Verify that the URL remains unchanged
|
||||
assert embedding_function._api_url == "https://localhost:8080/embed"
|
||||
|
||||
config_with_other_protocol = {
|
||||
"api_url": "http://localhost:8080/embed"
|
||||
}
|
||||
embedding_function = EmbeddingConfigurator()._configure_huggingface(
|
||||
config_with_other_protocol, "test-model"
|
||||
)
|
||||
# Verify that the URL remains unchanged
|
||||
assert embedding_function._api_url == "http://localhost:8080/embed"
|
||||
|
||||
config_no_url = {}
|
||||
embedding_function = EmbeddingConfigurator()._configure_huggingface(
|
||||
config_no_url, "test-model"
|
||||
)
|
||||
# Verify that no exception is raised when URL is None
|
||||
assert embedding_function._api_url == 'None'
|
||||
|
||||
|
||||
def test_huggingface_missing_protocol_with_json_source():
|
||||
"""Test that JSONKnowledgeSource works with Huggingface embedder without URL protocol."""
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource
|
||||
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
|
||||
|
||||
# Create a temporary JSON file
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as temp:
|
||||
json.dump({"test": "data", "nested": {"value": 123}}, temp)
|
||||
json_path = temp.name
|
||||
|
||||
# Test that the URL validation works in the embedder configurator
|
||||
config = {
|
||||
"api_url": "localhost:8080/embed" # Missing protocol
|
||||
}
|
||||
embedding_function = EmbeddingConfigurator()._configure_huggingface(
|
||||
config, "test-model"
|
||||
)
|
||||
# Verify that the URL now has a protocol
|
||||
assert embedding_function._api_url.startswith("http://")
|
||||
|
||||
os.unlink(json_path)
|
||||
|
||||
|
||||
def test_huggingface_missing_protocol_with_string_source():
|
||||
"""Test that StringKnowledgeSource works with Huggingface embedder without URL protocol."""
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
|
||||
|
||||
# Test that the URL validation works in the embedder configurator
|
||||
config = {
|
||||
"api_url": "localhost:8080/embed" # Missing protocol
|
||||
}
|
||||
embedding_function = EmbeddingConfigurator()._configure_huggingface(
|
||||
config, "test-model"
|
||||
)
|
||||
# Verify that the URL now has a protocol
|
||||
assert embedding_function._api_url.startswith("http://")
|
||||
|
||||
6
tests/knowledge/test_data.json
Normal file
6
tests/knowledge/test_data.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"test": "data",
|
||||
"nested": {
|
||||
"value": 123
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user