Merge branch 'main' of github.com:crewAIInc/crewAI into fix/interpolate-only-for-dict-list-input-types

This commit is contained in:
Lorenze Jay
2025-01-29 10:11:50 -08:00
14 changed files with 630 additions and 95 deletions

View File

@@ -43,7 +43,7 @@ Think of an agent as a specialized team member with specific skills, expertise,
 | **Max Retry Limit** _(optional)_ | `max_retry_limit` | `int` | Maximum number of retries when an error occurs. Default is 2. |
 | **Respect Context Window** _(optional)_ | `respect_context_window` | `bool` | Keep messages under context window size by summarizing. Default is True. |
 | **Code Execution Mode** _(optional)_ | `code_execution_mode` | `Literal["safe", "unsafe"]` | Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct). Default is 'safe'. |
-| **Embedder Config** _(optional)_ | `embedder_config` | `Optional[Dict[str, Any]]` | Configuration for the embedder used by the agent. |
+| **Embedder** _(optional)_ | `embedder` | `Optional[Dict[str, Any]]` | Configuration for the embedder used by the agent. |
 | **Knowledge Sources** _(optional)_ | `knowledge_sources` | `Optional[List[BaseKnowledgeSource]]` | Knowledge sources available to the agent. |
 | **Use System Prompt** _(optional)_ | `use_system_prompt` | `Optional[bool]` | Whether to use system prompt (for o1 model support). Default is True. |
@@ -152,7 +152,7 @@ agent = Agent(
     use_system_prompt=True,  # Default: True
     tools=[SerperDevTool()],  # Optional: List of tools
     knowledge_sources=None,  # Optional: List of knowledge sources
-    embedder_config=None,  # Optional: Custom embedder configuration
+    embedder=None,  # Optional: Custom embedder configuration
     system_template=None,  # Optional: Custom system prompt template
     prompt_template=None,  # Optional: Custom prompt template
     response_template=None,  # Optional: Custom response template
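
Editor's note: since this hunk renames the documented agent parameter from `embedder_config` to `embedder`, a minimal usage sketch of the new name may help. The provider and model values below are illustrative assumptions, not taken from the changed files.

```python
from crewai import Agent

# Minimal sketch of the renamed parameter: `embedder` (formerly `embedder_config`).
# The provider/config keys follow the documented dict shape; the model name is
# an illustrative assumption.
agent = Agent(
    role="Researcher",
    goal="Answer questions from the team's knowledge base",
    backstory="An analyst with access to curated knowledge sources.",
    embedder={
        "provider": "openai",
        "config": {"model": "text-embedding-3-small"},
    },
)
```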

View File

@@ -324,6 +324,13 @@ agent = Agent(
     verbose=True,
     allow_delegation=False,
     llm=gemini_llm,
+    embedder={
+        "provider": "google",
+        "config": {
+            "model": "models/text-embedding-004",
+            "api_key": GEMINI_API_KEY,
+        }
+    }
 )
 task = Task(

View File

@@ -33,11 +33,12 @@ crew = Crew(
 | :------------------------------- | :---------------- | :---------------------------- | :--------------------------------------------------------------------------------------------- |
 | **Description** | `description` | `str` | A clear, concise statement of what the task entails. |
 | **Expected Output** | `expected_output` | `str` | A detailed description of what the task's completion looks like. |
 | **Name** _(optional)_ | `name` | `Optional[str]` | A name identifier for the task. |
 | **Agent** _(optional)_ | `agent` | `Optional[BaseAgent]` | The agent responsible for executing the task. |
 | **Tools** _(optional)_ | `tools` | `List[BaseTool]` | The tools/resources the agent is limited to use for this task. |
 | **Context** _(optional)_ | `context` | `Optional[List["Task"]]` | Other tasks whose outputs will be used as context for this task. |
 | **Async Execution** _(optional)_ | `async_execution` | `Optional[bool]` | Whether the task should be executed asynchronously. Defaults to False. |
+| **Human Input** _(optional)_ | `human_input` | `Optional[bool]` | Whether the task should have a human review the final answer of the agent. Defaults to False. |
 | **Config** _(optional)_ | `config` | `Optional[Dict[str, Any]]` | Task-specific configuration parameters. |
 | **Output File** _(optional)_ | `output_file` | `Optional[str]` | File path for storing the task output. |
 | **Output JSON** _(optional)_ | `output_json` | `Optional[Type[BaseModel]]` | A Pydantic model to structure the JSON output. |
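
Editor's note: the table gains a `human_input` row; a short hedged sketch of the flag in use (the agent and task fields are illustrative, not from the changed files).

```python
from crewai import Agent, Task

# Sketch of the newly documented `human_input` flag: when True, a human
# reviews the agent's final answer before the task is considered complete.
writer = Agent(
    role="Writer",
    goal="Draft a short summary",
    backstory="A concise technical writer.",
)

summary_task = Task(
    description="Summarize the research notes in three bullet points.",
    expected_output="Three bullet points covering the key findings.",
    agent=writer,
    human_input=True,  # pause for human review of the final answer
)
```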

View File

@@ -61,6 +61,7 @@ class Agent(BaseAgent):
         tools: Tools at agents disposal
         step_callback: Callback to be executed after each step of the agent execution.
         knowledge_sources: Knowledge sources for the agent.
+        embedder: Embedder configuration for the agent.
     """

     _times_executed: int = PrivateAttr(default=0)
@@ -122,17 +123,10 @@ class Agent(BaseAgent):
default="safe", default="safe",
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).", description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
) )
embedder_config: Optional[Dict[str, Any]] = Field( embedder: Optional[Dict[str, Any]] = Field(
default=None, default=None,
description="Embedder configuration for the agent.", description="Embedder configuration for the agent.",
) )
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
default=None,
description="Knowledge sources for the agent.",
)
_knowledge: Optional[Knowledge] = PrivateAttr(
default=None,
)
@model_validator(mode="after") @model_validator(mode="after")
def post_init_setup(self): def post_init_setup(self):
@@ -163,10 +157,11 @@ class Agent(BaseAgent):
                 if isinstance(self.knowledge_sources, list) and all(
                     isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
                 ):
-                    self._knowledge = Knowledge(
+                    self.knowledge = Knowledge(
                         sources=self.knowledge_sources,
-                        embedder_config=self.embedder_config,
+                        embedder=self.embedder,
                         collection_name=knowledge_agent_name,
+                        storage=self.knowledge_storage or None,
                     )
             except (TypeError, ValueError) as e:
                 raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
@@ -225,8 +220,8 @@ class Agent(BaseAgent):
             if memory.strip() != "":
                 task_prompt += self.i18n.slice("memory").format(memory=memory)

-        if self._knowledge:
-            agent_knowledge_snippets = self._knowledge.query([task.prompt()])
+        if self.knowledge:
+            agent_knowledge_snippets = self.knowledge.query([task.prompt()])
             if agent_knowledge_snippets:
                 agent_knowledge_context = extract_knowledge_context(
                     agent_knowledge_snippets
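
Editor's note: the agent now exposes its assembled knowledge on the public `knowledge` field (previously the private `_knowledge`). A hedged sketch of what that enables; exactly when the Knowledge object is assembled is not visible in this hunk, and embedding requires configured credentials.

```python
from crewai import LLM, Agent
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

# Sketch: `knowledge_sources` feeds a Knowledge object that now lives on the
# public `agent.knowledge` field and is built with the agent's `embedder`
# and optional `knowledge_storage`.
facts = StringKnowledgeSource(content="Brandon's favorite color is red.")

agent = Agent(
    role="Information Agent",
    goal="Answer questions using the provided knowledge",
    backstory="You have access to specific knowledge sources.",
    llm=LLM(model="gpt-4o-mini"),
    knowledge_sources=[facts],
)

# Once assembled, the knowledge is queryable through the public field.
if agent.knowledge:
    print(agent.knowledge.query(["What is Brandon's favorite color?"]))
```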

View File

@@ -18,6 +18,8 @@ from pydantic_core import PydanticCustomError
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.agents.cache.cache_handler import CacheHandler
 from crewai.agents.tools_handler import ToolsHandler
+from crewai.knowledge.knowledge import Knowledge
+from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.tools import BaseTool
 from crewai.tools.base_tool import Tool
 from crewai.utilities import I18N, Logger, RPMController
@@ -48,6 +50,8 @@ class BaseAgent(ABC, BaseModel):
         cache_handler (InstanceOf[CacheHandler]): An instance of the CacheHandler class.
         tools_handler (InstanceOf[ToolsHandler]): An instance of the ToolsHandler class.
         max_tokens: Maximum number of tokens for the agent to generate in a response.
+        knowledge_sources: Knowledge sources for the agent.
+        knowledge_storage: Custom knowledge storage for the agent.

     Methods:
@@ -130,6 +134,17 @@ class BaseAgent(ABC, BaseModel):
     max_tokens: Optional[int] = Field(
         default=None, description="Maximum number of tokens for the agent's execution."
     )
+    knowledge: Optional[Knowledge] = Field(
+        default=None, description="Knowledge for the agent."
+    )
+    knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
+        default=None,
+        description="Knowledge sources for the agent.",
+    )
+    knowledge_storage: Optional[Any] = Field(
+        default=None,
+        description="Custom knowledge storage for the agent.",
+    )

     @model_validator(mode="before")
     @classmethod
@@ -256,13 +271,44 @@ class BaseAgent(ABC, BaseModel):
"tools_handler", "tools_handler",
"cache_handler", "cache_handler",
"llm", "llm",
"knowledge_sources",
"knowledge_storage",
"knowledge",
} }
# Copy llm and clear callbacks # Copy llm
existing_llm = shallow_copy(self.llm) existing_llm = shallow_copy(self.llm)
copied_knowledge = shallow_copy(self.knowledge)
copied_knowledge_storage = shallow_copy(self.knowledge_storage)
# Properly copy knowledge sources if they exist
existing_knowledge_sources = None
if self.knowledge_sources:
# Create a shared storage instance for all knowledge sources
shared_storage = (
self.knowledge_sources[0].storage if self.knowledge_sources else None
)
existing_knowledge_sources = []
for source in self.knowledge_sources:
copied_source = (
source.model_copy()
if hasattr(source, "model_copy")
else shallow_copy(source)
)
# Ensure all copied sources use the same storage instance
copied_source.storage = shared_storage
existing_knowledge_sources.append(copied_source)
copied_data = self.model_dump(exclude=exclude) copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None} copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = type(self)(**copied_data, llm=existing_llm, tools=self.tools) copied_agent = type(self)(
**copied_data,
llm=existing_llm,
tools=self.tools,
knowledge_sources=existing_knowledge_sources,
knowledge=copied_knowledge,
knowledge_storage=copied_knowledge_storage,
)
return copied_agent return copied_agent
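
Editor's note: a brief sketch of the behavior this `copy()` change is meant to provide; it assumes knowledge assembly succeeds (embedding credentials may be required) and mirrors the new test added further down in this diff.

```python
from crewai import Agent
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

# Sketch: cloning an agent now carries its knowledge configuration along,
# and all copied sources share a single storage instance (no re-embedding).
source = StringKnowledgeSource(content="Brandon's favorite color is red.")
agent = Agent(
    role="Information Agent",
    goal="Answer questions using the provided knowledge",
    backstory="You have access to specific knowledge sources.",
    knowledge_sources=[source],
)

agent_copy = agent.copy()
assert agent_copy.knowledge_sources is not None
assert len(agent_copy.knowledge_sources) == 1
assert agent_copy.knowledge_sources[0].storage is source.storage
```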

View File

@@ -4,6 +4,7 @@ import re
 import uuid
 import warnings
 from concurrent.futures import Future
+from copy import copy as shallow_copy
 from hashlib import md5
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -210,8 +211,9 @@ class Crew(BaseModel):
         default=None,
         description="LLM used to handle chatting with the crew.",
     )
-    _knowledge: Optional[Knowledge] = PrivateAttr(
+    knowledge: Optional[Knowledge] = Field(
         default=None,
+        description="Knowledge for the crew.",
     )

     @field_validator("id", mode="before")
@@ -289,7 +291,7 @@ class Crew(BaseModel):
                 if isinstance(self.knowledge_sources, list) and all(
                     isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
                 ):
-                    self._knowledge = Knowledge(
+                    self.knowledge = Knowledge(
                         sources=self.knowledge_sources,
                         embedder_config=self.embedder,
                         collection_name="crew",
@@ -991,8 +993,8 @@ class Crew(BaseModel):
         return result

     def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]:
-        if self._knowledge:
-            return self._knowledge.query(query)
+        if self.knowledge:
+            return self.knowledge.query(query)
         return None

     def fetch_inputs(self) -> Set[str]:
@@ -1036,6 +1038,8 @@ class Crew(BaseModel):
"_telemetry", "_telemetry",
"agents", "agents",
"tasks", "tasks",
"knowledge_sources",
"knowledge",
} }
cloned_agents = [agent.copy() for agent in self.agents] cloned_agents = [agent.copy() for agent in self.agents]
@@ -1043,6 +1047,9 @@ class Crew(BaseModel):
         task_mapping = {}

         cloned_tasks = []
+        existing_knowledge_sources = shallow_copy(self.knowledge_sources)
+        existing_knowledge = shallow_copy(self.knowledge)
+
         for task in self.tasks:
             cloned_task = task.copy(cloned_agents, task_mapping)
             cloned_tasks.append(cloned_task)
@@ -1062,7 +1069,13 @@ class Crew(BaseModel):
         copied_data.pop("agents", None)
         copied_data.pop("tasks", None)

-        copied_crew = Crew(**copied_data, agents=cloned_agents, tasks=cloned_tasks)
+        copied_crew = Crew(
+            **copied_data,
+            agents=cloned_agents,
+            tasks=cloned_tasks,
+            knowledge_sources=existing_knowledge_sources,
+            knowledge=existing_knowledge,
+        )

         return copied_crew
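
Editor's note: the crew-level counterpart of the same fix; a minimal hedged sketch (the agents and tasks are placeholders), echoing the new crew test at the end of this diff.

```python
from crewai import Agent, Crew, Task
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

# Sketch: Crew.copy() now passes knowledge_sources and knowledge explicitly
# to the cloned crew instead of losing them in the copy.
facts = StringKnowledgeSource(content="Brandon's favorite color is red.")
analyst = Agent(role="Analyst", goal="Answer questions", backstory="Knows the facts.")
answer = Task(description="Answer a question", expected_output="A short answer", agent=analyst)

crew = Crew(agents=[analyst], tasks=[answer], knowledge_sources=[facts])
crew_copy = crew.copy()
assert crew_copy.knowledge_sources == crew.knowledge_sources
```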

View File

@@ -15,20 +15,20 @@ class Knowledge(BaseModel):
     Args:
         sources: List[BaseKnowledgeSource] = Field(default_factory=list)
         storage: Optional[KnowledgeStorage] = Field(default=None)
-        embedder_config: Optional[Dict[str, Any]] = None
+        embedder: Optional[Dict[str, Any]] = None
     """

     sources: List[BaseKnowledgeSource] = Field(default_factory=list)
     model_config = ConfigDict(arbitrary_types_allowed=True)
     storage: Optional[KnowledgeStorage] = Field(default=None)
-    embedder_config: Optional[Dict[str, Any]] = None
+    embedder: Optional[Dict[str, Any]] = None
     collection_name: Optional[str] = None

     def __init__(
         self,
         collection_name: str,
         sources: List[BaseKnowledgeSource],
-        embedder_config: Optional[Dict[str, Any]] = None,
+        embedder: Optional[Dict[str, Any]] = None,
         storage: Optional[KnowledgeStorage] = None,
         **data,
     ):
@@ -37,25 +37,23 @@ class Knowledge(BaseModel):
             self.storage = storage
         else:
             self.storage = KnowledgeStorage(
-                embedder_config=embedder_config, collection_name=collection_name
+                embedder=embedder, collection_name=collection_name
             )
         self.sources = sources
         self.storage.initialize_knowledge_storage()
-        for source in sources:
-            source.storage = self.storage
-            source.add()
+        self._add_sources()

     def query(self, query: List[str], limit: int = 3) -> List[Dict[str, Any]]:
         """
         Query across all knowledge sources to find the most relevant information.
         Returns the top_k most relevant chunks.

         Raises:
             ValueError: If storage is not initialized.
         """
         if self.storage is None:
             raise ValueError("Storage is not initialized.")

         results = self.storage.search(
             query,
             limit,
@@ -63,6 +61,9 @@ class Knowledge(BaseModel):
         return results

     def _add_sources(self):
-        for source in self.sources:
-            source.storage = self.storage
-            source.add()
+        try:
+            for source in self.sources:
+                source.storage = self.storage
+                source.add()
+        except Exception as e:
+            raise e
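
Editor's note: with `embedder_config` renamed to `embedder` on `Knowledge`, direct construction looks like the sketch below; the provider and model values are illustrative assumptions.

```python
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

# Sketch of the renamed constructor argument: `embedder` replaces
# `embedder_config`; sources are now attached via the internal _add_sources().
knowledge = Knowledge(
    collection_name="demo",
    sources=[StringKnowledgeSource(content="Brandon's favorite color is red.")],
    embedder={
        "provider": "openai",
        "config": {"model": "text-embedding-3-small"},
    },
)

print(knowledge.query(["What is Brandon's favorite color?"], limit=1))
```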

View File

@@ -29,7 +29,13 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
     def validate_file_path(cls, v, info):
         """Validate that at least one of file_path or file_paths is provided."""
         # Single check if both are None, O(1) instead of nested conditions
-        if v is None and info.data.get("file_path" if info.field_name == "file_paths" else "file_paths") is None:
+        if (
+            v is None
+            and info.data.get(
+                "file_path" if info.field_name == "file_paths" else "file_paths"
+            )
+            is None
+        ):
             raise ValueError("Either file_path or file_paths must be provided")
         return v
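
Editor's note: the reformatted validator keeps the same contract; a tiny sketch, assuming `PDFKnowledgeSource` as a concrete subclass of `BaseFileKnowledgeSource` and that the referenced file exists.

```python
from crewai.knowledge.source.pdf_knowledge_source import PDFKnowledgeSource  # assumed subclass

# The rule enforced above: a file-based knowledge source needs file_path or
# file_paths; omitting both fails validation with
# "Either file_path or file_paths must be provided".
report = PDFKnowledgeSource(file_paths=["knowledge/report.pdf"])  # assumes the file exists
```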

View File

@@ -48,11 +48,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
     def __init__(
         self,
-        embedder_config: Optional[Dict[str, Any]] = None,
+        embedder: Optional[Dict[str, Any]] = None,
         collection_name: Optional[str] = None,
     ):
         self.collection_name = collection_name
-        self._set_embedder_config(embedder_config)
+        self._set_embedder_config(embedder)

     def search(
         self,
@@ -99,7 +99,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
             )
             if self.app:
                 self.collection = self.app.get_or_create_collection(
-                    name=collection_name, embedding_function=self.embedder_config
+                    name=collection_name, embedding_function=self.embedder
                 )
             else:
                 raise Exception("Vector Database Client not initialized")
@@ -187,17 +187,15 @@ class KnowledgeStorage(BaseKnowledgeStorage):
             api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
         )

-    def _set_embedder_config(
-        self, embedder_config: Optional[Dict[str, Any]] = None
-    ) -> None:
+    def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
         """Set the embedding configuration for the knowledge storage.

         Args:
             embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
                 If None or empty, defaults to the default embedding function.
         """
-        self.embedder_config = (
-            EmbeddingConfigurator().configure_embedder(embedder_config)
-            if embedder_config
+        self.embedder = (
+            EmbeddingConfigurator().configure_embedder(embedder)
+            if embedder
             else self._create_default_embedding_function()
         )
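
Editor's note: `KnowledgeStorage` renames its constructor parameter the same way; a hedged sketch of direct use (provider/model values are assumptions; with `embedder=None` it falls back to the default `text-embedding-3-small` function shown above).

```python
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage

# Sketch: `embedder` replaces `embedder_config` when constructing the storage
# directly; the dict is passed through EmbeddingConfigurator.
storage = KnowledgeStorage(
    embedder={
        "provider": "openai",
        "config": {"model": "text-embedding-3-small"},
    },
    collection_name="demo",
)
storage.initialize_knowledge_storage()
```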

View File

@@ -43,7 +43,6 @@ class EmbeddingConfigurator:
             raise Exception(
                 f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
             )
-
         return self.embedding_functions[provider](config, model_name)

     @staticmethod

View File

@@ -10,13 +10,14 @@ from crewai import Agent, Crew, Task
 from crewai.agents.cache import CacheHandler
 from crewai.agents.crew_agent_executor import CrewAgentExecutor
 from crewai.agents.parser import AgentAction, CrewAgentParser, OutputParserException
+from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
 from crewai.llm import LLM
 from crewai.tools import tool
 from crewai.tools.tool_calling import InstructorToolCalling
 from crewai.tools.tool_usage import ToolUsage
 from crewai.tools.tool_usage_events import ToolUsageFinished
-from crewai.utilities import Printer, RPMController
+from crewai.utilities import RPMController
 from crewai.utilities.events import Emitter
@@ -1602,6 +1603,45 @@ def test_agent_with_knowledge_sources():
assert "red" in result.raw.lower() assert "red" in result.raw.lower()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_knowledge_sources_works_with_copy():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
with patch(
"crewai.knowledge.source.base_knowledge_source.BaseKnowledgeSource",
autospec=True,
) as MockKnowledgeSource:
mock_knowledge_source_instance = MockKnowledgeSource.return_value
mock_knowledge_source_instance.__class__ = BaseKnowledgeSource
mock_knowledge_source_instance.sources = [string_source]
agent = Agent(
role="Information Agent",
goal="Provide information based on knowledge sources",
backstory="You have access to specific knowledge sources.",
llm=LLM(model="gpt-4o-mini"),
knowledge_sources=[string_source],
)
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as MockKnowledgeStorage:
mock_knowledge_storage = MockKnowledgeStorage.return_value
agent.knowledge_storage = mock_knowledge_storage
agent_copy = agent.copy()
assert agent_copy.role == agent.role
assert agent_copy.goal == agent.goal
assert agent_copy.backstory == agent.backstory
assert agent_copy.knowledge_sources is not None
assert len(agent_copy.knowledge_sources) == 1
assert isinstance(agent_copy.knowledge_sources[0], StringKnowledgeSource)
assert agent_copy.knowledge_sources[0].content == content
assert isinstance(agent_copy.llm, LLM)
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_litellm_auth_error_handling(): def test_litellm_auth_error_handling():
"""Test that LiteLLM authentication errors are handled correctly and not retried.""" """Test that LiteLLM authentication errors are handled correctly and not retried."""

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -14,6 +14,7 @@ from crewai.agent import Agent
 from crewai.agents.cache import CacheHandler
 from crewai.crew import Crew
 from crewai.crews.crew_output import CrewOutput
+from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
 from crewai.memory.contextual.contextual_memory import ContextualMemory
 from crewai.process import Process
 from crewai.project import crew
@@ -555,12 +556,12 @@ def test_crew_with_delegating_agents_should_not_override_task_tools():
     _, kwargs = mock_execute_sync.call_args
     tools = kwargs["tools"]

-    assert any(
-        isinstance(tool, TestTool) for tool in tools
-    ), "TestTool should be present"
-    assert any(
-        "delegate" in tool.name.lower() for tool in tools
-    ), "Delegation tool should be present"
+    assert any(isinstance(tool, TestTool) for tool in tools), (
+        "TestTool should be present"
+    )
+    assert any("delegate" in tool.name.lower() for tool in tools), (
+        "Delegation tool should be present"
+    )


 @pytest.mark.vcr(filter_headers=["authorization"])
@@ -619,12 +620,12 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools():
     _, kwargs = mock_execute_sync.call_args
     tools = kwargs["tools"]

-    assert any(
-        isinstance(tool, TestTool) for tool in new_ceo.tools
-    ), "TestTool should be present"
-    assert any(
-        "delegate" in tool.name.lower() for tool in tools
-    ), "Delegation tool should be present"
+    assert any(isinstance(tool, TestTool) for tool in new_ceo.tools), (
+        "TestTool should be present"
+    )
+    assert any("delegate" in tool.name.lower() for tool in tools), (
+        "Delegation tool should be present"
+    )


 @pytest.mark.vcr(filter_headers=["authorization"])
@@ -748,17 +749,17 @@ def test_task_tools_override_agent_tools_with_allow_delegation():
     used_tools = kwargs["tools"]

     # Confirm AnotherTestTool is present but TestTool is not
-    assert any(
-        isinstance(tool, AnotherTestTool) for tool in used_tools
-    ), "AnotherTestTool should be present"
-    assert not any(
-        isinstance(tool, TestTool) for tool in used_tools
-    ), "TestTool should not be present among used tools"
+    assert any(isinstance(tool, AnotherTestTool) for tool in used_tools), (
+        "AnotherTestTool should be present"
+    )
+    assert not any(isinstance(tool, TestTool) for tool in used_tools), (
+        "TestTool should not be present among used tools"
+    )

     # Confirm delegation tool(s) are present
-    assert any(
-        "delegate" in tool.name.lower() for tool in used_tools
-    ), "Delegation tool should be present"
+    assert any("delegate" in tool.name.lower() for tool in used_tools), (
+        "Delegation tool should be present"
+    )

     # Finally, make sure the agent's original tools remain unchanged
     assert len(researcher_with_delegation.tools) == 1
@@ -1466,7 +1467,6 @@ def test_dont_set_agents_step_callback_if_already_set():
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_crew_function_calling_llm():
     from crewai import LLM
     from crewai.tools import tool
@@ -1560,9 +1560,9 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff():
     # Verify that exactly one tool was used and it was a CodeInterpreterTool
     assert len(used_tools) == 1, "Should have exactly one tool"
-    assert isinstance(
-        used_tools[0], CodeInterpreterTool
-    ), "Tool should be CodeInterpreterTool"
+    assert isinstance(used_tools[0], CodeInterpreterTool), (
+        "Tool should be CodeInterpreterTool"
+    )


 @pytest.mark.vcr(filter_headers=["authorization"])
@@ -3107,9 +3107,9 @@ def test_fetch_inputs():
     expected_placeholders = {"role_detail", "topic", "field"}
     actual_placeholders = crew.fetch_inputs()

-    assert (
-        actual_placeholders == expected_placeholders
-    ), f"Expected {expected_placeholders}, but got {actual_placeholders}"
+    assert actual_placeholders == expected_placeholders, (
+        f"Expected {expected_placeholders}, but got {actual_placeholders}"
+    )


 def test_task_tools_preserve_code_execution_tools():
@@ -3182,20 +3182,20 @@ def test_task_tools_preserve_code_execution_tools():
     used_tools = kwargs["tools"]

     # Verify all expected tools are present
-    assert any(
-        isinstance(tool, TestTool) for tool in used_tools
-    ), "Task's TestTool should be present"
-    assert any(
-        isinstance(tool, CodeInterpreterTool) for tool in used_tools
-    ), "CodeInterpreterTool should be present"
-    assert any(
-        "delegate" in tool.name.lower() for tool in used_tools
-    ), "Delegation tool should be present"
+    assert any(isinstance(tool, TestTool) for tool in used_tools), (
+        "Task's TestTool should be present"
+    )
+    assert any(isinstance(tool, CodeInterpreterTool) for tool in used_tools), (
+        "CodeInterpreterTool should be present"
+    )
+    assert any("delegate" in tool.name.lower() for tool in used_tools), (
+        "Delegation tool should be present"
+    )

     # Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools)
-    assert (
-        len(used_tools) == 4
-    ), "Should have TestTool, CodeInterpreter, and 2 delegation tools"
+    assert len(used_tools) == 4, (
+        "Should have TestTool, CodeInterpreter, and 2 delegation tools"
+    )


 @pytest.mark.vcr(filter_headers=["authorization"])
@@ -3239,9 +3239,9 @@ def test_multimodal_flag_adds_multimodal_tools():
     used_tools = kwargs["tools"]

     # Check that the multimodal tool was added
-    assert any(
-        isinstance(tool, AddImageTool) for tool in used_tools
-    ), "AddImageTool should be present when agent is multimodal"
+    assert any(isinstance(tool, AddImageTool) for tool in used_tools), (
+        "AddImageTool should be present when agent is multimodal"
+    )

     # Verify we have exactly one tool (just the AddImageTool)
     assert len(used_tools) == 1, "Should only have the AddImageTool"
@@ -3467,9 +3467,9 @@ def test_crew_guardrail_feedback_in_context():
     assert len(execution_contexts) > 1, "Task should have been executed multiple times"

     # Verify that the second execution included the guardrail feedback
-    assert (
-        "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1]
-    ), "Guardrail feedback should be included in retry context"
+    assert "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1], (
+        "Guardrail feedback should be included in retry context"
+    )

     # Verify final output meets guardrail requirements
     assert "IMPORTANT" in result.raw, "Final output should contain required keyword"
@@ -3494,7 +3494,6 @@ def test_before_kickoff_callback():
         @before_kickoff
         def modify_inputs(self, inputs):
             self.inputs_modified = True
             inputs["modified"] = True
             return inputs
@@ -3596,3 +3595,21 @@ def test_before_kickoff_without_inputs():
     # Verify that the inputs were initialized and modified inside the before_kickoff method
     assert test_crew_instance.received_inputs is not None
     assert test_crew_instance.received_inputs.get("modified") is True
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_crew_with_knowledge_sources_works_with_copy():
+    content = "Brandon's favorite color is red and he likes Mexican food."
+    string_source = StringKnowledgeSource(content=content)
+
+    crew = Crew(
+        agents=[researcher, writer],
+        tasks=[Task(description="test", expected_output="test", agent=researcher)],
+        knowledge_sources=[string_source],
+    )
+
+    crew_copy = crew.copy()
+    assert crew_copy.knowledge_sources == crew.knowledge_sources
+    assert len(crew_copy.agents) == len(crew.agents)
+    assert len(crew_copy.tasks) == len(crew.tasks)