mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-20 05:18:16 +00:00
Compare commits
1 Commits
devin/1763
...
devin/1764
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b889621e30 |
@@ -950,34 +950,15 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _handle_crew_planning(self) -> None:
|
||||
"""Handles the Crew planning."""
|
||||
import re
|
||||
|
||||
self._logger.log("info", "Planning the crew execution")
|
||||
result = CrewPlanner(
|
||||
tasks=self.tasks, planning_agent_llm=self.planning_llm
|
||||
)._handle_crew_planning()
|
||||
|
||||
plan_map: dict[int, str] = {}
|
||||
for step_plan in result.list_of_plans_per_task:
|
||||
match = re.search(r"Task Number (\d+)", step_plan.task, re.IGNORECASE)
|
||||
if match:
|
||||
task_number = int(match.group(1))
|
||||
plan_map[task_number] = step_plan.plan
|
||||
else:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Could not extract task number from plan task field: {step_plan.task}",
|
||||
)
|
||||
|
||||
for idx, task in enumerate(self.tasks):
|
||||
task_number = idx + 1 # Task numbers are 1-indexed
|
||||
if task_number in plan_map:
|
||||
task.description += plan_map[task_number]
|
||||
else:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"No plan found for task {task_number}. Task description: {task.description}",
|
||||
)
|
||||
for task, step_plan in zip(
|
||||
self.tasks, result.list_of_plans_per_task, strict=False
|
||||
):
|
||||
task.description += step_plan.plan
|
||||
|
||||
def _store_execution_log(
|
||||
self,
|
||||
|
||||
@@ -1,21 +1,66 @@
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
"""HuggingFace embeddings provider."""
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):
|
||||
"""HuggingFace embeddings provider using the Inference API.
|
||||
|
||||
embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
|
||||
default=HuggingFaceEmbeddingServer,
|
||||
This provider uses the HuggingFace Inference API for text embeddings.
|
||||
It supports configuration via direct parameters or environment variables.
|
||||
|
||||
Example:
|
||||
embedder={
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "your-hf-token",
|
||||
"model": "sentence-transformers/all-MiniLM-L6-v2"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
embedding_callable: type[HuggingFaceEmbeddingFunction] = Field(
|
||||
default=HuggingFaceEmbeddingFunction,
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="HuggingFace API key for authentication",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_API_KEY",
|
||||
"HUGGINGFACE_API_KEY",
|
||||
"HF_TOKEN",
|
||||
),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_MODEL",
|
||||
"HUGGINGFACE_MODEL",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
api_key_env_var: str = Field(
|
||||
default="CHROMA_HUGGINGFACE_API_KEY",
|
||||
description="Environment variable name containing the API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_API_KEY_ENV_VAR",
|
||||
"HUGGINGFACE_API_KEY_ENV_VAR",
|
||||
),
|
||||
)
|
||||
api_url: str | None = Field(
|
||||
default=None,
|
||||
description="API URL (accepted for compatibility but not used by HuggingFace Inference API)",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_URL",
|
||||
"HUGGINGFACE_URL",
|
||||
"url",
|
||||
),
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
@@ -6,8 +6,24 @@ from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for HuggingFace provider."""
|
||||
"""Configuration for HuggingFace provider.
|
||||
|
||||
Supports HuggingFace Inference API for text embeddings.
|
||||
|
||||
Attributes:
|
||||
api_key: HuggingFace API key for authentication.
|
||||
model: Model name to use for embeddings (e.g., "sentence-transformers/all-MiniLM-L6-v2").
|
||||
model_name: Alias for model.
|
||||
api_key_env_var: Environment variable name containing the API key.
|
||||
api_url: Optional API URL (accepted but not used, for compatibility).
|
||||
url: Alias for api_url (accepted but not used, for compatibility).
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
model: str
|
||||
model_name: str
|
||||
api_key_env_var: str
|
||||
api_url: str
|
||||
url: str
|
||||
|
||||
|
||||
|
||||
@@ -176,6 +176,98 @@ class TestEmbeddingFactory:
|
||||
"crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider"
|
||||
)
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_huggingface(self, mock_import):
|
||||
"""Test building HuggingFace embedder with api_key and model."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "hf-test-key",
|
||||
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "hf-test-key"
|
||||
assert call_kwargs["model"] == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_huggingface_with_api_url(self, mock_import):
|
||||
"""Test building HuggingFace embedder with api_url (for compatibility)."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "hf-test-key",
|
||||
"model": "Qwen/Qwen3-Embedding-0.6B",
|
||||
"api_url": "https://api-inference.huggingface.co",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "hf-test-key"
|
||||
assert call_kwargs["model"] == "Qwen/Qwen3-Embedding-0.6B"
|
||||
assert call_kwargs["api_url"] == "https://api-inference.huggingface.co"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_huggingface_with_model_name(self, mock_import):
|
||||
"""Test building HuggingFace embedder with model_name alias."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "hf-test-key",
|
||||
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "hf-test-key"
|
||||
assert call_kwargs["model_name"] == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
def test_build_embedder_unknown_provider(self):
|
||||
"""Test error handling for unknown provider."""
|
||||
config = {"provider": "unknown-provider", "config": {}}
|
||||
|
||||
143
lib/crewai/tests/rag/embeddings/test_huggingface_provider.py
Normal file
143
lib/crewai/tests/rag/embeddings/test_huggingface_provider.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Tests for HuggingFace embedding provider."""
|
||||
|
||||
import pytest
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
|
||||
HuggingFaceProvider,
|
||||
)
|
||||
|
||||
|
||||
class TestHuggingFaceProvider:
|
||||
"""Test HuggingFace embedding provider."""
|
||||
|
||||
def test_provider_with_api_key_and_model(self):
|
||||
"""Test provider initialization with api_key and model.
|
||||
|
||||
This tests the fix for GitHub issue #3995 where users couldn't
|
||||
configure HuggingFace embedder with api_key and model.
|
||||
"""
|
||||
provider = HuggingFaceProvider(
|
||||
api_key="test-hf-token",
|
||||
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-hf-token"
|
||||
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert provider.embedding_callable == HuggingFaceEmbeddingFunction
|
||||
|
||||
def test_provider_with_model_alias(self):
|
||||
"""Test provider initialization with 'model' alias for model_name."""
|
||||
provider = HuggingFaceProvider(
|
||||
api_key="test-hf-token",
|
||||
model="Qwen/Qwen3-Embedding-0.6B",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-hf-token"
|
||||
assert provider.model_name == "Qwen/Qwen3-Embedding-0.6B"
|
||||
|
||||
def test_provider_with_api_url_compatibility(self):
|
||||
"""Test provider accepts api_url for compatibility but excludes it from model_dump.
|
||||
|
||||
The api_url parameter is accepted for compatibility with the documented
|
||||
configuration format but is not passed to HuggingFaceEmbeddingFunction
|
||||
since it uses a fixed API endpoint.
|
||||
"""
|
||||
provider = HuggingFaceProvider(
|
||||
api_key="test-hf-token",
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
api_url="https://api-inference.huggingface.co",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-hf-token"
|
||||
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert provider.api_url == "https://api-inference.huggingface.co"
|
||||
|
||||
# api_url should be excluded from model_dump
|
||||
dumped = provider.model_dump(exclude={"embedding_callable"})
|
||||
assert "api_url" not in dumped
|
||||
|
||||
def test_provider_default_model(self):
|
||||
"""Test provider uses default model when not specified."""
|
||||
provider = HuggingFaceProvider(api_key="test-hf-token")
|
||||
|
||||
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
def test_provider_default_api_key_env_var(self):
|
||||
"""Test provider uses default api_key_env_var."""
|
||||
provider = HuggingFaceProvider(api_key="test-hf-token")
|
||||
|
||||
assert provider.api_key_env_var == "CHROMA_HUGGINGFACE_API_KEY"
|
||||
|
||||
|
||||
class TestHuggingFaceProviderIntegration:
|
||||
"""Integration tests for HuggingFace provider with build_embedder."""
|
||||
|
||||
def test_build_embedder_with_documented_config(self):
|
||||
"""Test build_embedder with the documented configuration format.
|
||||
|
||||
This tests the exact configuration format shown in the documentation
|
||||
that was failing before the fix for GitHub issue #3995.
|
||||
"""
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "test-hf-token",
|
||||
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"api_url": "https://api-inference.huggingface.co",
|
||||
},
|
||||
}
|
||||
|
||||
# This should not raise a validation error
|
||||
embedder = build_embedder(config)
|
||||
|
||||
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||
assert embedder.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
def test_build_embedder_with_minimal_config(self):
|
||||
"""Test build_embedder with minimal configuration."""
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "test-hf-token",
|
||||
},
|
||||
}
|
||||
|
||||
embedder = build_embedder(config)
|
||||
|
||||
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||
# Default model should be used
|
||||
assert embedder.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
def test_build_embedder_with_model_name_config(self):
|
||||
"""Test build_embedder with model_name instead of model."""
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "test-hf-token",
|
||||
"model_name": "sentence-transformers/paraphrase-MiniLM-L6-v2",
|
||||
},
|
||||
}
|
||||
|
||||
embedder = build_embedder(config)
|
||||
|
||||
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||
assert embedder.model_name == "sentence-transformers/paraphrase-MiniLM-L6-v2"
|
||||
|
||||
def test_build_embedder_with_custom_model(self):
|
||||
"""Test build_embedder with a custom model name."""
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "test-hf-token",
|
||||
"model": "Qwen/Qwen3-Embedding-0.6B",
|
||||
},
|
||||
}
|
||||
|
||||
embedder = build_embedder(config)
|
||||
|
||||
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||
assert embedder.model_name == "Qwen/Qwen3-Embedding-0.6B"
|
||||
@@ -4772,93 +4772,3 @@ def test_ensure_exchanged_messages_are_propagated_to_external_memory():
|
||||
assert "Researcher" in messages[0]["content"]
|
||||
assert messages[1]["role"] == "user"
|
||||
assert "Research a topic to teach a kid aged 6 about math" in messages[1]["content"]
|
||||
|
||||
|
||||
def test_crew_planning_with_mismatched_task_order():
|
||||
"""Test that crew planning correctly matches plans to tasks even when LLM returns them out of order.
|
||||
|
||||
This test reproduces the bug reported in issue #3953 where the task planner
|
||||
returns plans in the wrong order (e.g., starting with Task 21 instead of Task 1),
|
||||
causing plans to be attached to the wrong tasks.
|
||||
"""
|
||||
from crewai.utilities.planning_handler import PlanPerTask, PlannerTaskPydanticOutput
|
||||
|
||||
# Create 5 tasks with distinct descriptions
|
||||
tasks = []
|
||||
agents = []
|
||||
for i in range(1, 6):
|
||||
agent = Agent(
|
||||
role=f"Agent {i}",
|
||||
goal=f"Goal {i}",
|
||||
backstory=f"Backstory {i}",
|
||||
)
|
||||
agents.append(agent)
|
||||
task = Task(
|
||||
description=f"Task {i} description",
|
||||
expected_output=f"Output {i}",
|
||||
agent=agent,
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
planning=True,
|
||||
planning_llm="gpt-4o-mini",
|
||||
)
|
||||
|
||||
# Mock the LLM response to return plans in the WRONG order
|
||||
# Simulating the bug where Task 5 plan comes first, then Task 3, etc.
|
||||
wrong_order_plans = [
|
||||
PlanPerTask(
|
||||
task="Task Number 5 - Task 5 description",
|
||||
plan="\n\nPlan for task 5"
|
||||
),
|
||||
PlanPerTask(
|
||||
task="Task Number 3 - Task 3 description",
|
||||
plan="\n\nPlan for task 3"
|
||||
),
|
||||
PlanPerTask(
|
||||
task="Task Number 1 - Task 1 description",
|
||||
plan="\n\nPlan for task 1"
|
||||
),
|
||||
PlanPerTask(
|
||||
task="Task Number 4 - Task 4 description",
|
||||
plan="\n\nPlan for task 4"
|
||||
),
|
||||
PlanPerTask(
|
||||
task="Task Number 2 - Task 2 description",
|
||||
plan="\n\nPlan for task 2"
|
||||
),
|
||||
]
|
||||
|
||||
with patch.object(Task, "execute_sync") as mock_execute:
|
||||
mock_execute.return_value = TaskOutput(
|
||||
description="Planning task",
|
||||
agent="planner",
|
||||
pydantic=PlannerTaskPydanticOutput(
|
||||
list_of_plans_per_task=wrong_order_plans
|
||||
),
|
||||
)
|
||||
|
||||
# Call the planning method
|
||||
crew._handle_crew_planning()
|
||||
|
||||
# Verify that each task has the CORRECT plan appended to its description
|
||||
# Task 1 should have "Plan for task 1", not "Plan for task 5"
|
||||
assert "Plan for task 1" in crew.tasks[0].description, \
|
||||
f"Task 1 should have 'Plan for task 1' but got: {crew.tasks[0].description}"
|
||||
assert "Plan for task 2" in crew.tasks[1].description, \
|
||||
f"Task 2 should have 'Plan for task 2' but got: {crew.tasks[1].description}"
|
||||
assert "Plan for task 3" in crew.tasks[2].description, \
|
||||
f"Task 3 should have 'Plan for task 3' but got: {crew.tasks[2].description}"
|
||||
assert "Plan for task 4" in crew.tasks[3].description, \
|
||||
f"Task 4 should have 'Plan for task 4' but got: {crew.tasks[3].description}"
|
||||
assert "Plan for task 5" in crew.tasks[4].description, \
|
||||
f"Task 5 should have 'Plan for task 5' but got: {crew.tasks[4].description}"
|
||||
|
||||
# Also verify that wrong plans are NOT in the wrong tasks
|
||||
assert "Plan for task 5" not in crew.tasks[0].description, \
|
||||
"Task 1 should not have Plan for task 5"
|
||||
assert "Plan for task 3" not in crew.tasks[1].description, \
|
||||
"Task 2 should not have Plan for task 3"
|
||||
|
||||
Reference in New Issue
Block a user