Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-25 00:38:30 +00:00)

Compare commits: main...gl/feat/no (27 commits)
Commit SHAs:
515ce8f55f, 1d40f5d83c, 3afac2a696, 5fab437b7f, 30684f387e, f2b4efe7fa,
4f175fdd6f, d72b79f932, e8638d318d, d2c880c6b3, 087f6d25a9, c57e325482,
fdb7047780, adb485f7f7, ee64bd426e, 37b80ee937, bf9ccd418a, bd95356ec5,
441591d592, 132b6b224a, 4e2916d71a, 0c4a0e1fda, 9c4126e0d8, 5156fc4792,
c600b26ca6, 162a106002, be33c8e3e5
@@ -38,6 +38,7 @@ dependencies = [
     "pydantic-settings~=2.10.1",
     "mcp~=1.16.0",
     "uv~=0.9.13",
+    "aiosqlite~=0.21.0",
 ]

 [project.urls]
@@ -2,7 +2,6 @@ from __future__ import annotations

 import asyncio
 from collections.abc import Sequence
-import json
 import shutil
 import subprocess
 import time
@@ -19,6 +18,19 @@ from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
 from typing_extensions import Self

 from crewai.a2a.config import A2AConfig
+from crewai.agent.utils import (
+    ahandle_knowledge_retrieval,
+    apply_training_data,
+    build_task_prompt_with_schema,
+    format_task_with_context,
+    get_knowledge_config,
+    handle_knowledge_retrieval,
+    handle_reasoning,
+    prepare_tools,
+    process_tool_results,
+    save_last_messages,
+    validate_max_execution_time,
+)
 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.cache.cache_handler import CacheHandler
 from crewai.agents.crew_agent_executor import CrewAgentExecutor
@@ -27,9 +39,6 @@ from crewai.events.types.knowledge_events import (
     KnowledgeQueryCompletedEvent,
     KnowledgeQueryFailedEvent,
     KnowledgeQueryStartedEvent,
-    KnowledgeRetrievalCompletedEvent,
-    KnowledgeRetrievalStartedEvent,
-    KnowledgeSearchQueryFailedEvent,
 )
 from crewai.events.types.memory_events import (
     MemoryRetrievalCompletedEvent,
@@ -37,7 +46,6 @@ from crewai.events.types.memory_events import (
 )
 from crewai.knowledge.knowledge import Knowledge
 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
-from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
 from crewai.lite_agent import LiteAgent
 from crewai.llms.base_llm import BaseLLM
 from crewai.mcp import (
@@ -61,7 +69,7 @@ from crewai.utilities.agent_utils import (
     render_text_description_and_args,
 )
 from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
-from crewai.utilities.converter import Converter, generate_model_description
+from crewai.utilities.converter import Converter
 from crewai.utilities.guardrail_types import GuardrailType
 from crewai.utilities.llm_utils import create_llm
 from crewai.utilities.prompts import Prompts
@@ -295,53 +303,15 @@ class Agent(BaseAgent):
             ValueError: If the max execution time is not a positive integer.
             RuntimeError: If the agent execution fails for other reasons.
         """
-        if self.reasoning:
-            try:
-                from crewai.utilities.reasoning_handler import (
-                    AgentReasoning,
-                    AgentReasoningOutput,
-                )
-
-                reasoning_handler = AgentReasoning(task=task, agent=self)
-                reasoning_output: AgentReasoningOutput = (
-                    reasoning_handler.handle_agent_reasoning()
-                )
-
-                # Add the reasoning plan to the task description
-                task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
-            except Exception as e:
-                self._logger.log("error", f"Error during reasoning process: {e!s}")
+        handle_reasoning(self, task)
         self._inject_date_to_task(task)

         if self.tools_handler:
             self.tools_handler.last_used_tool = None

         task_prompt = task.prompt()
-
-        # If the task requires output in JSON or Pydantic format,
-        # append specific instructions to the task prompt to ensure
-        # that the final answer does not include any code block markers
-        # Skip this if task.response_model is set, as native structured outputs handle schema automatically
-        if (task.output_json or task.output_pydantic) and not task.response_model:
-            # Generate the schema based on the output format
-            if task.output_json:
-                schema_dict = generate_model_description(task.output_json)
-                schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
-                task_prompt += "\n" + self.i18n.slice(
-                    "formatted_task_instructions"
-                ).format(output_format=schema)
-
-            elif task.output_pydantic:
-                schema_dict = generate_model_description(task.output_pydantic)
-                schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
-                task_prompt += "\n" + self.i18n.slice(
-                    "formatted_task_instructions"
-                ).format(output_format=schema)
-
-        if context:
-            task_prompt = self.i18n.slice("task_with_context").format(
-                task=task_prompt, context=context
-            )
+        task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
+        task_prompt = format_task_with_context(task_prompt, context, self.i18n)

         if self._is_any_available_memory():
             crewai_event_bus.emit(
@@ -379,84 +349,20 @@ class Agent(BaseAgent):
                     from_task=task,
                 ),
             )
-        knowledge_config = (
-            self.knowledge_config.model_dump() if self.knowledge_config else {}
-        )
+        knowledge_config = get_knowledge_config(self)
+        task_prompt = handle_knowledge_retrieval(
+            self,
+            task,
+            task_prompt,
+            knowledge_config,
+            self.knowledge.query if self.knowledge else lambda *a, **k: None,
+            self.crew.query_knowledge if self.crew else lambda *a, **k: None,
+        )

-        if self.knowledge or (self.crew and self.crew.knowledge):
-            crewai_event_bus.emit(
-                self,
-                event=KnowledgeRetrievalStartedEvent(
-                    from_task=task,
-                    from_agent=self,
-                ),
-            )
-            try:
-                self.knowledge_search_query = self._get_knowledge_search_query(
-                    task_prompt, task
-                )
-                if self.knowledge_search_query:
-                    # Quering agent specific knowledge
-                    if self.knowledge:
-                        agent_knowledge_snippets = self.knowledge.query(
-                            [self.knowledge_search_query], **knowledge_config
-                        )
-                        if agent_knowledge_snippets:
-                            self.agent_knowledge_context = extract_knowledge_context(
-                                agent_knowledge_snippets
-                            )
-                            if self.agent_knowledge_context:
-                                task_prompt += self.agent_knowledge_context
+        prepare_tools(self, tools, task)
+        task_prompt = apply_training_data(self, task_prompt)

-                    # Quering crew specific knowledge
-                    knowledge_snippets = self.crew.query_knowledge(
-                        [self.knowledge_search_query], **knowledge_config
-                    )
-                    if knowledge_snippets:
-                        self.crew_knowledge_context = extract_knowledge_context(
-                            knowledge_snippets
-                        )
-                        if self.crew_knowledge_context:
-                            task_prompt += self.crew_knowledge_context
-
-                    crewai_event_bus.emit(
-                        self,
-                        event=KnowledgeRetrievalCompletedEvent(
-                            query=self.knowledge_search_query,
-                            from_task=task,
-                            from_agent=self,
-                            retrieved_knowledge=(
-                                (self.agent_knowledge_context or "")
-                                + (
-                                    "\n"
-                                    if self.agent_knowledge_context
-                                    and self.crew_knowledge_context
-                                    else ""
-                                )
-                                + (self.crew_knowledge_context or "")
-                            ),
-                        ),
-                    )
-            except Exception as e:
-                crewai_event_bus.emit(
-                    self,
-                    event=KnowledgeSearchQueryFailedEvent(
-                        query=self.knowledge_search_query or "",
-                        error=str(e),
-                        from_task=task,
-                        from_agent=self,
-                    ),
-                )
-
-        tools = tools or self.tools or []
-        self.create_agent_executor(tools=tools, task=task)
-
-        if self.crew and self.crew._train:
-            task_prompt = self._training_handler(task_prompt=task_prompt)
-        else:
-            task_prompt = self._use_trained_data(task_prompt=task_prompt)

         # Import agent events locally to avoid circular imports
         from crewai.events.types.agent_events import (
             AgentExecutionCompletedEvent,
             AgentExecutionErrorEvent,
@@ -474,15 +380,8 @@ class Agent(BaseAgent):
                 ),
             )

-            # Determine execution method based on timeout setting
+            validate_max_execution_time(self.max_execution_time)
             if self.max_execution_time is not None:
-                if (
-                    not isinstance(self.max_execution_time, int)
-                    or self.max_execution_time <= 0
-                ):
-                    raise ValueError(
-                        "Max Execution time must be a positive integer greater than zero"
-                    )
                 result = self._execute_with_timeout(
                     task_prompt, task, self.max_execution_time
                 )
@@ -490,7 +389,6 @@ class Agent(BaseAgent):
                 result = self._execute_without_timeout(task_prompt, task)

         except TimeoutError as e:
-            # Propagate TimeoutError without retry
             crewai_event_bus.emit(
                 self,
                 event=AgentExecutionErrorEvent(
@@ -502,7 +400,6 @@ class Agent(BaseAgent):
             raise e
         except Exception as e:
             if e.__class__.__module__.startswith("litellm"):
-                # Do not retry on litellm errors
                 crewai_event_bus.emit(
                     self,
                     event=AgentExecutionErrorEvent(
@@ -528,23 +425,13 @@ class Agent(BaseAgent):
         if self.max_rpm and self._rpm_controller:
             self._rpm_controller.stop_rpm_counter()

-        # If there was any tool in self.tools_results that had result_as_answer
-        # set to True, return the results of the last tool that had
-        # result_as_answer set to True
-        for tool_result in self.tools_results:
-            if tool_result.get("result_as_answer", False):
-                result = tool_result["result"]
+        result = process_tool_results(self, result)
         crewai_event_bus.emit(
             self,
             event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
         )

-        self._last_messages = (
-            self.agent_executor.messages.copy()
-            if self.agent_executor and hasattr(self.agent_executor, "messages")
-            else []
-        )
-
+        save_last_messages(self)
         self._cleanup_mcp_clients()

         return result
@@ -604,6 +491,208 @@ class Agent(BaseAgent):
             }
         )["output"]

+    async def aexecute_task(
+        self,
+        task: Task,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
+    ) -> Any:
+        """Execute a task with the agent asynchronously.
+
+        Args:
+            task: Task to execute.
+            context: Context to execute the task in.
+            tools: Tools to use for the task.
+
+        Returns:
+            Output of the agent.
+
+        Raises:
+            TimeoutError: If execution exceeds the maximum execution time.
+            ValueError: If the max execution time is not a positive integer.
+            RuntimeError: If the agent execution fails for other reasons.
+        """
+        handle_reasoning(self, task)
+        self._inject_date_to_task(task)
+
+        if self.tools_handler:
+            self.tools_handler.last_used_tool = None
+
+        task_prompt = task.prompt()
+        task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
+        task_prompt = format_task_with_context(task_prompt, context, self.i18n)
+
+        if self._is_any_available_memory():
+            crewai_event_bus.emit(
+                self,
+                event=MemoryRetrievalStartedEvent(
+                    task_id=str(task.id) if task else None,
+                    source_type="agent",
+                    from_agent=self,
+                    from_task=task,
+                ),
+            )
+
+            start_time = time.time()
+
+            contextual_memory = ContextualMemory(
+                self.crew._short_term_memory,
+                self.crew._long_term_memory,
+                self.crew._entity_memory,
+                self.crew._external_memory,
+                agent=self,
+                task=task,
+            )
+            memory = await contextual_memory.abuild_context_for_task(
+                task, context or ""
+            )
+            if memory.strip() != "":
+                task_prompt += self.i18n.slice("memory").format(memory=memory)
+
+            crewai_event_bus.emit(
+                self,
+                event=MemoryRetrievalCompletedEvent(
+                    task_id=str(task.id) if task else None,
+                    memory_content=memory,
+                    retrieval_time_ms=(time.time() - start_time) * 1000,
+                    source_type="agent",
+                    from_agent=self,
+                    from_task=task,
+                ),
+            )
+
+        knowledge_config = get_knowledge_config(self)
+        task_prompt = await ahandle_knowledge_retrieval(
+            self, task, task_prompt, knowledge_config
+        )
+
+        prepare_tools(self, tools, task)
+        task_prompt = apply_training_data(self, task_prompt)
+
+        from crewai.events.types.agent_events import (
+            AgentExecutionCompletedEvent,
+            AgentExecutionErrorEvent,
+            AgentExecutionStartedEvent,
+        )
+
+        try:
+            crewai_event_bus.emit(
+                self,
+                event=AgentExecutionStartedEvent(
+                    agent=self,
+                    tools=self.tools,
+                    task_prompt=task_prompt,
+                    task=task,
+                ),
+            )
+
+            validate_max_execution_time(self.max_execution_time)
+            if self.max_execution_time is not None:
+                result = await self._aexecute_with_timeout(
+                    task_prompt, task, self.max_execution_time
+                )
+            else:
+                result = await self._aexecute_without_timeout(task_prompt, task)
+
+        except TimeoutError as e:
+            crewai_event_bus.emit(
+                self,
+                event=AgentExecutionErrorEvent(
+                    agent=self,
+                    task=task,
+                    error=str(e),
+                ),
+            )
+            raise e
+        except Exception as e:
+            if e.__class__.__module__.startswith("litellm"):
+                crewai_event_bus.emit(
+                    self,
+                    event=AgentExecutionErrorEvent(
+                        agent=self,
+                        task=task,
+                        error=str(e),
+                    ),
+                )
+                raise e
+            self._times_executed += 1
+            if self._times_executed > self.max_retry_limit:
+                crewai_event_bus.emit(
+                    self,
+                    event=AgentExecutionErrorEvent(
+                        agent=self,
+                        task=task,
+                        error=str(e),
+                    ),
+                )
+                raise e
+            result = await self.aexecute_task(task, context, tools)
+
+        if self.max_rpm and self._rpm_controller:
+            self._rpm_controller.stop_rpm_counter()
+
+        result = process_tool_results(self, result)
+        crewai_event_bus.emit(
+            self,
+            event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
+        )
+
+        save_last_messages(self)
+        self._cleanup_mcp_clients()
+
+        return result
+
+    async def _aexecute_with_timeout(
+        self, task_prompt: str, task: Task, timeout: int
+    ) -> Any:
+        """Execute a task with a timeout asynchronously.
+
+        Args:
+            task_prompt: The prompt to send to the agent.
+            task: The task being executed.
+            timeout: Maximum execution time in seconds.
+
+        Returns:
+            The output of the agent.
+
+        Raises:
+            TimeoutError: If execution exceeds the timeout.
+            RuntimeError: If execution fails for other reasons.
+        """
+        try:
+            return await asyncio.wait_for(
+                self._aexecute_without_timeout(task_prompt, task),
+                timeout=timeout,
+            )
+        except asyncio.TimeoutError as e:
+            raise TimeoutError(
+                f"Task '{task.description}' execution timed out after {timeout} seconds. "
+                "Consider increasing max_execution_time or optimizing the task."
+            ) from e
+
+    async def _aexecute_without_timeout(self, task_prompt: str, task: Task) -> Any:
+        """Execute a task without a timeout asynchronously.
+
+        Args:
+            task_prompt: The prompt to send to the agent.
+            task: The task being executed.
+
+        Returns:
+            The output of the agent.
+        """
+        if not self.agent_executor:
+            raise RuntimeError("Agent executor is not initialized.")
+
+        result = await self.agent_executor.ainvoke(
+            {
+                "input": task_prompt,
+                "tool_names": self.agent_executor.tools_names,
+                "tools": self.agent_executor.tools_description,
+                "ask_for_human_input": task.human_input,
+            }
+        )
+        return result["output"]
+
     def create_agent_executor(
         self, tools: list[BaseTool] | None = None, task: Task | None = None
     ) -> None:
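
The `_aexecute_with_timeout` helper above wraps the untimed coroutine in `asyncio.wait_for` and converts `asyncio.TimeoutError` into the built-in `TimeoutError` with a descriptive message. A minimal self-contained sketch of the same pattern (the `slow_work` coroutine is a stand-in, not part of the diff):

import asyncio

async def slow_work() -> str:
    # Stand-in for the agent's untimed execution path.
    await asyncio.sleep(5)
    return "done"

async def execute_with_timeout(timeout: int) -> str:
    try:
        return await asyncio.wait_for(slow_work(), timeout=timeout)
    except asyncio.TimeoutError as e:
        # Re-raise as the built-in TimeoutError, mirroring the diff above.
        raise TimeoutError(f"execution timed out after {timeout}s") from e

try:
    asyncio.run(execute_with_timeout(1))
except TimeoutError as e:
    print(e)  # -> execution timed out after 1s
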
@@ -633,7 +722,7 @@ class Agent(BaseAgent):
         )

         self.agent_executor = CrewAgentExecutor(
-            llm=self.llm,
+            llm=self.llm,  # type: ignore[arg-type]
             task=task,  # type: ignore[arg-type]
             agent=self,
             crew=self.crew,
@@ -810,6 +899,7 @@ class Agent(BaseAgent):
         from crewai.tools.base_tool import BaseTool
         from crewai.tools.mcp_native_tool import MCPNativeTool

+        transport: StdioTransport | HTTPTransport | SSETransport
         if isinstance(mcp_config, MCPServerStdio):
             transport = StdioTransport(
                 command=mcp_config.command,
@@ -903,10 +993,10 @@ class Agent(BaseAgent):
                             server_name=server_name,
                             run_context=None,
                         )
-                        if mcp_config.tool_filter(context, tool):
+                        if mcp_config.tool_filter(context, tool):  # type: ignore[call-arg, arg-type]
                             filtered_tools.append(tool)
                     except (TypeError, AttributeError):
-                        if mcp_config.tool_filter(tool):
+                        if mcp_config.tool_filter(tool):  # type: ignore[call-arg, arg-type]
                             filtered_tools.append(tool)
                 else:
                     # Not callable - include tool
@@ -981,7 +1071,9 @@ class Agent(BaseAgent):
         path = parsed.path.replace("/", "_").strip("_")
         return f"{domain}_{path}" if path else domain

-    def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
+    def _get_mcp_tool_schemas(
+        self, server_params: dict[str, Any]
+    ) -> dict[str, dict[str, Any]]:
         """Get tool schemas from MCP server for wrapper creation with caching."""
         server_url = server_params["url"]
@@ -995,7 +1087,7 @@ class Agent(BaseAgent):
                 self._logger.log(
                     "debug", f"Using cached MCP tool schemas for {server_url}"
                 )
-                return cached_data
+                return cached_data  # type: ignore[no-any-return]

         try:
             schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
@@ -1013,7 +1105,7 @@ class Agent(BaseAgent):

     async def _get_mcp_tool_schemas_async(
         self, server_params: dict[str, Any]
-    ) -> dict[str, dict]:
+    ) -> dict[str, dict[str, Any]]:
         """Async implementation of MCP tool schema retrieval with timeouts and retries."""
         server_url = server_params["url"]
         return await self._retry_mcp_discovery(
@@ -1021,7 +1113,7 @@ class Agent(BaseAgent):
         )

     async def _retry_mcp_discovery(
-        self, operation_func, server_url: str
+        self, operation_func: Any, server_url: str
     ) -> dict[str, dict[str, Any]]:
         """Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
         last_error = None
@@ -1052,7 +1144,7 @@ class Agent(BaseAgent):

     @staticmethod
     async def _attempt_mcp_discovery(
-        operation_func, server_url: str
+        operation_func: Any, server_url: str
     ) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
         """Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
         try:
@@ -1142,7 +1234,7 @@ class Agent(BaseAgent):
         properties = json_schema.get("properties", {})
         required_fields = json_schema.get("required", [])

-        field_definitions = {}
+        field_definitions: dict[str, Any] = {}

         for field_name, field_schema in properties.items():
             field_type = self._json_type_to_python(field_schema)
@@ -1162,7 +1254,7 @@ class Agent(BaseAgent):
         )

         model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
-        return create_model(model_name, **field_definitions)
+        return create_model(model_name, **field_definitions)  # type: ignore[no-any-return]

     def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
         """Convert JSON Schema type to Python type.
@@ -1177,7 +1269,7 @@ class Agent(BaseAgent):
         json_type = field_schema.get("type")

         if "anyOf" in field_schema:
-            types = []
+            types: list[type] = []
             for option in field_schema["anyOf"]:
                 if "const" in option:
                     types.append(str)
@@ -1185,13 +1277,13 @@ class Agent(BaseAgent):
                     types.append(self._json_type_to_python(option))
             unique_types = list(set(types))
             if len(unique_types) > 1:
-                result = unique_types[0]
+                result: Any = unique_types[0]
                 for t in unique_types[1:]:
                     result = result | t
-                return result
+                return result  # type: ignore[no-any-return]
             return unique_types[0]

-        type_mapping = {
+        type_mapping: dict[str | None, type] = {
             "string": str,
             "number": float,
             "integer": int,
@@ -1203,7 +1295,7 @@ class Agent(BaseAgent):
         return type_mapping.get(json_type, Any)

     @staticmethod
-    def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
+    def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
         """Fetch MCP server configurations from CrewAI AOP API."""
         # TODO: Implement AMP API call to "integrations/mcps" endpoint
         # Should return list of server configs with URLs
@@ -1438,11 +1530,11 @@ class Agent(BaseAgent):
         """
         if self.apps:
             platform_tools = self.get_platform_tools(self.apps)
-            if platform_tools:
+            if platform_tools and self.tools is not None:
                 self.tools.extend(platform_tools)
         if self.mcps:
             mcps = self.get_mcp_tools(self.mcps)
-            if mcps:
+            if mcps and self.tools is not None:
                 self.tools.extend(mcps)

         lite_agent = LiteAgent(
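
A hedged usage sketch of the new async entry point added above. It assumes crewai is installed and that `Agent` and `Task` take their usual role/goal/backstory and description/expected_output fields; the concrete values are placeholders, not part of the diff:

import asyncio

from crewai import Agent, Task

async def main() -> None:
    agent = Agent(
        role="Researcher",
        goal="Summarize a topic",
        backstory="An experienced analyst.",
    )
    task = Task(
        description="Summarize how native async execution works.",
        expected_output="A short summary.",
        agent=agent,
    )
    # New in this branch: awaitable counterpart of execute_task.
    result = await agent.aexecute_task(task)
    print(result)

asyncio.run(main())
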

lib/crewai/src/crewai/agent/utils.py (new file, 355 lines)
@@ -0,0 +1,355 @@
"""Utility functions for agent task execution.

This module contains shared logic extracted from the Agent's execute_task
and aexecute_task methods to reduce code duplication.
"""

from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.knowledge_events import (
    KnowledgeRetrievalCompletedEvent,
    KnowledgeRetrievalStartedEvent,
    KnowledgeSearchQueryFailedEvent,
)
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.utilities.converter import generate_model_description


if TYPE_CHECKING:
    from crewai.agent.core import Agent
    from crewai.task import Task
    from crewai.tools.base_tool import BaseTool
    from crewai.utilities.i18n import I18N


def handle_reasoning(agent: Agent, task: Task) -> None:
    """Handle the reasoning process for an agent before task execution.

    Args:
        agent: The agent performing the task.
        task: The task to execute.
    """
    if not agent.reasoning:
        return

    try:
        from crewai.utilities.reasoning_handler import (
            AgentReasoning,
            AgentReasoningOutput,
        )

        reasoning_handler = AgentReasoning(task=task, agent=agent)
        reasoning_output: AgentReasoningOutput = (
            reasoning_handler.handle_agent_reasoning()
        )
        task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
    except Exception as e:
        agent._logger.log("error", f"Error during reasoning process: {e!s}")


def build_task_prompt_with_schema(task: Task, task_prompt: str, i18n: I18N) -> str:
    """Build task prompt with JSON/Pydantic schema instructions if applicable.

    Args:
        task: The task being executed.
        task_prompt: The initial task prompt.
        i18n: Internationalization instance.

    Returns:
        The task prompt potentially augmented with schema instructions.
    """
    if (task.output_json or task.output_pydantic) and not task.response_model:
        if task.output_json:
            schema_dict = generate_model_description(task.output_json)
            schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
            task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
                output_format=schema
            )
        elif task.output_pydantic:
            schema_dict = generate_model_description(task.output_pydantic)
            schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
            task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
                output_format=schema
            )
    return task_prompt


def format_task_with_context(task_prompt: str, context: str | None, i18n: I18N) -> str:
    """Format task prompt with context if provided.

    Args:
        task_prompt: The task prompt.
        context: Optional context string.
        i18n: Internationalization instance.

    Returns:
        The task prompt formatted with context if provided.
    """
    if context:
        return i18n.slice("task_with_context").format(task=task_prompt, context=context)
    return task_prompt


def get_knowledge_config(agent: Agent) -> dict[str, Any]:
    """Get knowledge configuration from agent.

    Args:
        agent: The agent instance.

    Returns:
        Dictionary of knowledge configuration.
    """
    return agent.knowledge_config.model_dump() if agent.knowledge_config else {}


def handle_knowledge_retrieval(
    agent: Agent,
    task: Task,
    task_prompt: str,
    knowledge_config: dict[str, Any],
    query_func: Any,
    crew_query_func: Any,
) -> str:
    """Handle knowledge retrieval for task execution.

    This function handles both agent-specific and crew-specific knowledge queries.

    Args:
        agent: The agent performing the task.
        task: The task being executed.
        task_prompt: The current task prompt.
        knowledge_config: Knowledge configuration dictionary.
        query_func: Function to query agent knowledge (sync or async).
        crew_query_func: Function to query crew knowledge (sync or async).

    Returns:
        The task prompt potentially augmented with knowledge context.
    """
    if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
        return task_prompt

    crewai_event_bus.emit(
        agent,
        event=KnowledgeRetrievalStartedEvent(
            from_task=task,
            from_agent=agent,
        ),
    )
    try:
        agent.knowledge_search_query = agent._get_knowledge_search_query(
            task_prompt, task
        )
        if agent.knowledge_search_query:
            if agent.knowledge:
                agent_knowledge_snippets = query_func(
                    [agent.knowledge_search_query], **knowledge_config
                )
                if agent_knowledge_snippets:
                    agent.agent_knowledge_context = extract_knowledge_context(
                        agent_knowledge_snippets
                    )
                    if agent.agent_knowledge_context:
                        task_prompt += agent.agent_knowledge_context

            knowledge_snippets = crew_query_func(
                [agent.knowledge_search_query], **knowledge_config
            )
            if knowledge_snippets:
                agent.crew_knowledge_context = extract_knowledge_context(
                    knowledge_snippets
                )
                if agent.crew_knowledge_context:
                    task_prompt += agent.crew_knowledge_context

            crewai_event_bus.emit(
                agent,
                event=KnowledgeRetrievalCompletedEvent(
                    query=agent.knowledge_search_query,
                    from_task=task,
                    from_agent=agent,
                    retrieved_knowledge=_combine_knowledge_context(agent),
                ),
            )
    except Exception as e:
        crewai_event_bus.emit(
            agent,
            event=KnowledgeSearchQueryFailedEvent(
                query=agent.knowledge_search_query or "",
                error=str(e),
                from_task=task,
                from_agent=agent,
            ),
        )
    return task_prompt


def _combine_knowledge_context(agent: Agent) -> str:
    """Combine agent and crew knowledge contexts into a single string.

    Args:
        agent: The agent with knowledge contexts.

    Returns:
        Combined knowledge context string.
    """
    agent_ctx = agent.agent_knowledge_context or ""
    crew_ctx = agent.crew_knowledge_context or ""
    separator = "\n" if agent_ctx and crew_ctx else ""
    return agent_ctx + separator + crew_ctx


def apply_training_data(agent: Agent, task_prompt: str) -> str:
    """Apply training data to the task prompt.

    Args:
        agent: The agent performing the task.
        task_prompt: The task prompt.

    Returns:
        The task prompt with training data applied.
    """
    if agent.crew and agent.crew._train:
        return agent._training_handler(task_prompt=task_prompt)
    return agent._use_trained_data(task_prompt=task_prompt)


def process_tool_results(agent: Agent, result: Any) -> Any:
    """Process tool results, returning result_as_answer if applicable.

    Args:
        agent: The agent with tool results.
        result: The current result.

    Returns:
        The final result, potentially overridden by tool result_as_answer.
    """
    for tool_result in agent.tools_results:
        if tool_result.get("result_as_answer", False):
            result = tool_result["result"]
    return result


def save_last_messages(agent: Agent) -> None:
    """Save the last messages from agent executor.

    Args:
        agent: The agent instance.
    """
    agent._last_messages = (
        agent.agent_executor.messages.copy()
        if agent.agent_executor and hasattr(agent.agent_executor, "messages")
        else []
    )


def prepare_tools(
    agent: Agent, tools: list[BaseTool] | None, task: Task
) -> list[BaseTool]:
    """Prepare tools for task execution and create agent executor.

    Args:
        agent: The agent instance.
        tools: Optional list of tools.
        task: The task being executed.

    Returns:
        The list of tools to use.
    """
    final_tools = tools or agent.tools or []
    agent.create_agent_executor(tools=final_tools, task=task)
    return final_tools


def validate_max_execution_time(max_execution_time: int | None) -> None:
    """Validate max_execution_time parameter.

    Args:
        max_execution_time: The maximum execution time to validate.

    Raises:
        ValueError: If max_execution_time is not a positive integer.
    """
    if max_execution_time is not None:
        if not isinstance(max_execution_time, int) or max_execution_time <= 0:
            raise ValueError(
                "Max Execution time must be a positive integer greater than zero"
            )


async def ahandle_knowledge_retrieval(
    agent: Agent,
    task: Task,
    task_prompt: str,
    knowledge_config: dict[str, Any],
) -> str:
    """Handle async knowledge retrieval for task execution.

    Args:
        agent: The agent performing the task.
        task: The task being executed.
        task_prompt: The current task prompt.
        knowledge_config: Knowledge configuration dictionary.

    Returns:
        The task prompt potentially augmented with knowledge context.
    """
    if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
        return task_prompt

    crewai_event_bus.emit(
        agent,
        event=KnowledgeRetrievalStartedEvent(
            from_task=task,
            from_agent=agent,
        ),
    )
    try:
        agent.knowledge_search_query = agent._get_knowledge_search_query(
            task_prompt, task
        )
        if agent.knowledge_search_query:
            if agent.knowledge:
                agent_knowledge_snippets = await agent.knowledge.aquery(
                    [agent.knowledge_search_query], **knowledge_config
                )
                if agent_knowledge_snippets:
                    agent.agent_knowledge_context = extract_knowledge_context(
                        agent_knowledge_snippets
                    )
                    if agent.agent_knowledge_context:
                        task_prompt += agent.agent_knowledge_context

            knowledge_snippets = await agent.crew.aquery_knowledge(
                [agent.knowledge_search_query], **knowledge_config
            )
            if knowledge_snippets:
                agent.crew_knowledge_context = extract_knowledge_context(
                    knowledge_snippets
                )
                if agent.crew_knowledge_context:
                    task_prompt += agent.crew_knowledge_context

            crewai_event_bus.emit(
                agent,
                event=KnowledgeRetrievalCompletedEvent(
                    query=agent.knowledge_search_query,
                    from_task=task,
                    from_agent=agent,
                    retrieved_knowledge=_combine_knowledge_context(agent),
                ),
            )
    except Exception as e:
        crewai_event_bus.emit(
            agent,
            event=KnowledgeSearchQueryFailedEvent(
                query=agent.knowledge_search_query or "",
                error=str(e),
                from_task=task,
                from_agent=agent,
            ),
        )
    return task_prompt
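
Most helpers in the new module take an `Agent`, but `validate_max_execution_time` is standalone; a quick sketch of its contract, assuming the package layout shown above:

from crewai.agent.utils import validate_max_execution_time

validate_max_execution_time(30)    # positive integer: passes silently
validate_max_execution_time(None)  # None means "no timeout": also passes

try:
    validate_max_execution_time(0)  # zero or negative raises
except ValueError as e:
    print(e)  # -> Max Execution time must be a positive integer greater than zero
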
@@ -265,7 +265,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
         if not mcps:
             return mcps

-        validated_mcps = []
+        validated_mcps: list[str | MCPServerConfig] = []
         for mcp in mcps:
             if isinstance(mcp, str):
                 if mcp.startswith(("https://", "crewai-amp:")):
@@ -347,6 +347,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
     ) -> str:
         pass

+    @abstractmethod
+    async def aexecute_task(
+        self,
+        task: Any,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
+    ) -> str:
+        """Execute a task asynchronously."""
+
     @abstractmethod
     def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
         pass
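
Because `aexecute_task` is now abstract on `BaseAgent`, custom agent subclasses must implement it alongside the sync method. A minimal sketch of a hypothetical subclass; the real `BaseAgent` carries additional required fields and abstract members beyond what is shown here:

from typing import Any

from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tools.base_tool import BaseTool

class EchoAgent(BaseAgent):
    def execute_task(
        self, task: Any, context: str | None = None, tools: list[BaseTool] | None = None
    ) -> str:
        # Synchronous path, as before this branch.
        return f"sync: {task}"

    async def aexecute_task(
        self, task: Any, context: str | None = None, tools: list[BaseTool] | None = None
    ) -> str:
        # New awaitable counterpart required by the abstract method above.
        return f"async: {task}"

    def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
        pass
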
@@ -28,6 +28,7 @@ from crewai.hooks.llm_hooks import (
     get_before_llm_call_hooks,
 )
 from crewai.utilities.agent_utils import (
+    aget_llm_response,
     enforce_rpm_limit,
     format_message_for_llm,
     get_llm_response,
@@ -43,7 +44,10 @@ from crewai.utilities.agent_utils import (
 from crewai.utilities.constants import TRAINING_DATA_FILE
 from crewai.utilities.i18n import I18N, get_i18n
 from crewai.utilities.printer import Printer
-from crewai.utilities.tool_utils import execute_tool_and_check_finality
+from crewai.utilities.tool_utils import (
+    aexecute_tool_and_check_finality,
+    execute_tool_and_check_finality,
+)
 from crewai.utilities.training_handler import CrewTrainingHandler
@@ -134,8 +138,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.messages: list[LLMMessage] = []
         self.iterations = 0
         self.log_error_after = 3
-        self.before_llm_call_hooks: list[Callable] = []
-        self.after_llm_call_hooks: list[Callable] = []
+        self.before_llm_call_hooks: list[Callable[..., Any]] = []
+        self.after_llm_call_hooks: list[Callable[..., Any]] = []
         self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
         self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
         if self.llm:
@@ -312,6 +316,154 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self._show_logs(formatted_answer)
         return formatted_answer

+    async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
+        """Execute the agent asynchronously with given inputs.
+
+        Args:
+            inputs: Input dictionary containing prompt variables.
+
+        Returns:
+            Dictionary with agent output.
+        """
+        if "system" in self.prompt:
+            system_prompt = self._format_prompt(
+                cast(str, self.prompt.get("system", "")), inputs
+            )
+            user_prompt = self._format_prompt(
+                cast(str, self.prompt.get("user", "")), inputs
+            )
+            self.messages.append(format_message_for_llm(system_prompt, role="system"))
+            self.messages.append(format_message_for_llm(user_prompt))
+        else:
+            user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
+            self.messages.append(format_message_for_llm(user_prompt))
+
+        self._show_start_logs()
+
+        self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
+
+        try:
+            formatted_answer = await self._ainvoke_loop()
+        except AssertionError:
+            self._printer.print(
+                content="Agent failed to reach a final answer. This is likely a bug - please report it.",
+                color="red",
+            )
+            raise
+        except Exception as e:
+            handle_unknown_error(self._printer, e)
+            raise
+
+        if self.ask_for_human_input:
+            formatted_answer = self._handle_human_feedback(formatted_answer)
+
+        self._create_short_term_memory(formatted_answer)
+        self._create_long_term_memory(formatted_answer)
+        self._create_external_memory(formatted_answer)
+        return {"output": formatted_answer.output}
+
+    async def _ainvoke_loop(self) -> AgentFinish:
+        """Execute agent loop asynchronously until completion.
+
+        Returns:
+            Final answer from the agent.
+        """
+        formatted_answer = None
+        while not isinstance(formatted_answer, AgentFinish):
+            try:
+                if has_reached_max_iterations(self.iterations, self.max_iter):
+                    formatted_answer = handle_max_iterations_exceeded(
+                        formatted_answer,
+                        printer=self._printer,
+                        i18n=self._i18n,
+                        messages=self.messages,
+                        llm=self.llm,
+                        callbacks=self.callbacks,
+                    )
+                    break
+
+                enforce_rpm_limit(self.request_within_rpm_limit)
+
+                answer = await aget_llm_response(
+                    llm=self.llm,
+                    messages=self.messages,
+                    callbacks=self.callbacks,
+                    printer=self._printer,
+                    from_task=self.task,
+                    from_agent=self.agent,
+                    response_model=self.response_model,
+                    executor_context=self,
+                )
+                formatted_answer = process_llm_response(answer, self.use_stop_words)  # type: ignore[assignment]
+
+                if isinstance(formatted_answer, AgentAction):
+                    fingerprint_context = {}
+                    if (
+                        self.agent
+                        and hasattr(self.agent, "security_config")
+                        and hasattr(self.agent.security_config, "fingerprint")
+                    ):
+                        fingerprint_context = {
+                            "agent_fingerprint": str(
+                                self.agent.security_config.fingerprint
+                            )
+                        }
+
+                    tool_result = await aexecute_tool_and_check_finality(
+                        agent_action=formatted_answer,
+                        fingerprint_context=fingerprint_context,
+                        tools=self.tools,
+                        i18n=self._i18n,
+                        agent_key=self.agent.key if self.agent else None,
+                        agent_role=self.agent.role if self.agent else None,
+                        tools_handler=self.tools_handler,
+                        task=self.task,
+                        agent=self.agent,
+                        function_calling_llm=self.function_calling_llm,
+                        crew=self.crew,
+                    )
+                    formatted_answer = self._handle_agent_action(
+                        formatted_answer, tool_result
+                    )
+
+                self._invoke_step_callback(formatted_answer)  # type: ignore[arg-type]
+                self._append_message(formatted_answer.text)  # type: ignore[union-attr,attr-defined]
+
+            except OutputParserError as e:
+                formatted_answer = handle_output_parser_exception(  # type: ignore[assignment]
+                    e=e,
+                    messages=self.messages,
+                    iterations=self.iterations,
+                    log_error_after=self.log_error_after,
+                    printer=self._printer,
+                )
+
+            except Exception as e:
+                if e.__class__.__module__.startswith("litellm"):
+                    raise e
+                if is_context_length_exceeded(e):
+                    handle_context_length(
+                        respect_context_window=self.respect_context_window,
+                        printer=self._printer,
+                        messages=self.messages,
+                        llm=self.llm,
+                        callbacks=self.callbacks,
+                        i18n=self._i18n,
+                    )
+                    continue
+                handle_unknown_error(self._printer, e)
+                raise e
+            finally:
+                self.iterations += 1

+        if not isinstance(formatted_answer, AgentFinish):
+            raise RuntimeError(
+                "Agent execution ended without reaching a final answer. "
+                f"Got {type(formatted_answer).__name__} instead of AgentFinish."
+            )
+        self._show_logs(formatted_answer)
+        return formatted_answer
+
     def _handle_agent_action(
         self, formatted_answer: AgentAction, tool_result: ToolResult
     ) -> AgentAction | AgentFinish:
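
`_ainvoke_loop` keeps calling the LLM until the response parses as an `AgentFinish`, bounded by `max_iter`. A self-contained sketch of that loop shape with stand-in types (nothing below is crewai API):

import asyncio
from dataclasses import dataclass

@dataclass
class Action:   # stand-in for AgentAction
    tool: str

@dataclass
class Finish:   # stand-in for AgentFinish
    output: str

async def llm_step(iteration: int) -> Action | Finish:
    # Pretend the model calls a tool twice, then produces a final answer.
    return Action("search") if iteration < 2 else Finish("final answer")

async def invoke_loop(max_iter: int = 5) -> Finish:
    iterations = 0
    answer: Action | Finish | None = None
    while not isinstance(answer, Finish):
        if iterations >= max_iter:
            raise RuntimeError("no final answer within max_iter")
        answer = await llm_step(iterations)
        iterations += 1
    return answer

print(asyncio.run(invoke_loop()).output)  # -> final answer
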
@@ -327,7 +327,7 @@ class Crew(FlowTrackable, BaseModel):
     def set_private_attrs(self) -> Crew:
         """set private attributes."""
         self._cache_handler = CacheHandler()
-        event_listener = EventListener()  # type: ignore[no-untyped-call]
+        event_listener = EventListener()

         # Determine and set tracing state once for this execution
         tracing_enabled = should_enable_tracing(override=self.tracing)
@@ -348,12 +348,12 @@ class Crew(FlowTrackable, BaseModel):
         return self

     def _initialize_default_memories(self) -> None:
-        self._long_term_memory = self._long_term_memory or LongTermMemory()  # type: ignore[no-untyped-call]
-        self._short_term_memory = self._short_term_memory or ShortTermMemory(  # type: ignore[no-untyped-call]
+        self._long_term_memory = self._long_term_memory or LongTermMemory()
+        self._short_term_memory = self._short_term_memory or ShortTermMemory(
             crew=self,
             embedder_config=self.embedder,
         )
-        self._entity_memory = self.entity_memory or EntityMemory(  # type: ignore[no-untyped-call]
+        self._entity_memory = self.entity_memory or EntityMemory(
             crew=self, embedder_config=self.embedder
         )
@@ -948,6 +948,342 @@ class Crew(FlowTrackable, BaseModel):
         self._task_output_handler.reset()
         return list(results)

+    async def akickoff(
+        self, inputs: dict[str, Any] | None = None
+    ) -> CrewOutput | CrewStreamingOutput:
+        """Native async kickoff method using async task execution throughout.
+
+        Unlike kickoff_async which wraps sync kickoff in a thread, this method
+        uses native async/await for all operations including task execution,
+        memory operations, and knowledge queries.
+        """
+        if self.stream:
+            for agent in self.agents:
+                if agent.llm is not None:
+                    agent.llm.stream = True
+
+            result_holder: list[CrewOutput] = []
+            current_task_info: TaskInfo = {
+                "index": 0,
+                "name": "",
+                "id": "",
+                "agent_role": "",
+                "agent_id": "",
+            }
+
+            state = create_streaming_state(
+                current_task_info, result_holder, use_async=True
+            )
+            output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
+
+            async def run_crew() -> None:
+                try:
+                    self.stream = False
+                    result = await self.akickoff(inputs)
+                    if isinstance(result, CrewOutput):
+                        result_holder.append(result)
+                except Exception as e:
+                    signal_error(state, e, is_async=True)
+                finally:
+                    self.stream = True
+                    signal_end(state, is_async=True)
+
+            streaming_output = CrewStreamingOutput(
+                async_iterator=create_async_chunk_generator(
+                    state, run_crew, output_holder
+                )
+            )
+            output_holder.append(streaming_output)
+
+            return streaming_output
+
+        ctx = baggage.set_baggage(
+            "crew_context", CrewContext(id=str(self.id), key=self.key)
+        )
+        token = attach(ctx)
+
+        try:
+            for before_callback in self.before_kickoff_callbacks:
+                if inputs is None:
+                    inputs = {}
+                inputs = before_callback(inputs)
+
+            crewai_event_bus.emit(
+                self,
+                CrewKickoffStartedEvent(crew_name=self.name, inputs=inputs),
+            )
+
+            self._task_output_handler.reset()
+            self._logging_color = "bold_purple"
+
+            if inputs is not None:
+                self._inputs = inputs
+                self._interpolate_inputs(inputs)
+            self._set_tasks_callbacks()
+            self._set_allow_crewai_trigger_context_for_first_task()
+
+            for agent in self.agents:
+                agent.crew = self
+                agent.set_knowledge(crew_embedder=self.embedder)
+                if not agent.function_calling_llm:  # type: ignore[attr-defined]
+                    agent.function_calling_llm = self.function_calling_llm  # type: ignore[attr-defined]
+
+                if not agent.step_callback:  # type: ignore[attr-defined]
+                    agent.step_callback = self.step_callback  # type: ignore[attr-defined]
+
+                agent.create_agent_executor()
+
+            if self.planning:
+                self._handle_crew_planning()
+
+            if self.process == Process.sequential:
+                result = await self._arun_sequential_process()
+            elif self.process == Process.hierarchical:
+                result = await self._arun_hierarchical_process()
+            else:
+                raise NotImplementedError(
+                    f"The process '{self.process}' is not implemented yet."
+                )
+
+            for after_callback in self.after_kickoff_callbacks:
+                result = after_callback(result)
+
+            self.usage_metrics = self.calculate_usage_metrics()
+
+            return result
+        except Exception as e:
+            crewai_event_bus.emit(
+                self,
+                CrewKickoffFailedEvent(error=str(e), crew_name=self.name),
+            )
+            raise
+        finally:
+            detach(token)
+
+    async def akickoff_for_each(
+        self, inputs: list[dict[str, Any]]
+    ) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
+        """Native async execution of the Crew's workflow for each input.
+
+        Uses native async throughout rather than thread-based async.
+        If stream=True, returns a single CrewStreamingOutput that yields chunks
+        from all crews as they arrive.
+        """
+        crew_copies = [self.copy() for _ in inputs]
+
+        if self.stream:
+            result_holder: list[list[CrewOutput]] = [[]]
+            current_task_info: TaskInfo = {
+                "index": 0,
+                "name": "",
+                "id": "",
+                "agent_role": "",
+                "agent_id": "",
+            }
+
+            state = create_streaming_state(
+                current_task_info, result_holder, use_async=True
+            )
+            output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
+
+            async def run_all_crews() -> None:
+                try:
+                    streaming_outputs: list[CrewStreamingOutput] = []
+                    for i, crew in enumerate(crew_copies):
+                        streaming = await crew.akickoff(inputs=inputs[i])
+                        if isinstance(streaming, CrewStreamingOutput):
+                            streaming_outputs.append(streaming)
+
+                    async def consume_stream(
+                        stream_output: CrewStreamingOutput,
+                    ) -> CrewOutput:
+                        async for chunk in stream_output:
+                            if state.async_queue is not None and state.loop is not None:
+                                state.loop.call_soon_threadsafe(
+                                    state.async_queue.put_nowait, chunk
+                                )
+                        return stream_output.result
+
+                    crew_results = await asyncio.gather(
+                        *[consume_stream(s) for s in streaming_outputs]
+                    )
+                    result_holder[0] = list(crew_results)
+                except Exception as e:
+                    signal_error(state, e, is_async=True)
+                finally:
+                    signal_end(state, is_async=True)
+
+            streaming_output = CrewStreamingOutput(
+                async_iterator=create_async_chunk_generator(
+                    state, run_all_crews, output_holder
+                )
+            )
+
+            def set_results_wrapper(result: Any) -> None:
+                streaming_output._set_results(result)
+
+            streaming_output._set_result = set_results_wrapper  # type: ignore[method-assign]
+            output_holder.append(streaming_output)
+
+            return streaming_output
+
+        tasks = [
+            asyncio.create_task(crew_copy.akickoff(inputs=input_data))
+            for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
+        ]
+
+        results = await asyncio.gather(*tasks)
+
+        total_usage_metrics = UsageMetrics()
+        for crew_copy in crew_copies:
+            if crew_copy.usage_metrics:
+                total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
+        self.usage_metrics = total_usage_metrics
+
+        self._task_output_handler.reset()
+        return list(results)
+
+    async def _arun_sequential_process(self) -> CrewOutput:
+        """Executes tasks sequentially using native async and returns the final output."""
+        return await self._aexecute_tasks(self.tasks)
+
+    async def _arun_hierarchical_process(self) -> CrewOutput:
+        """Creates and assigns a manager agent to complete the tasks using native async."""
+        self._create_manager_agent()
+        return await self._aexecute_tasks(self.tasks)
+
+    async def _aexecute_tasks(
+        self,
+        tasks: list[Task],
+        start_index: int | None = 0,
+        was_replayed: bool = False,
+    ) -> CrewOutput:
+        """Executes tasks using native async and returns the final output.
+
+        Args:
+            tasks: List of tasks to execute
+            start_index: Index to start execution from (for replay)
+            was_replayed: Whether this is a replayed execution
+
+        Returns:
+            CrewOutput: Final output of the crew
+        """
+        task_outputs: list[TaskOutput] = []
+        pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = []
+        last_sync_output: TaskOutput | None = None
+
+        for task_index, task in enumerate(tasks):
+            if start_index is not None and task_index < start_index:
+                if task.output:
+                    if task.async_execution:
+                        task_outputs.append(task.output)
+                    else:
+                        task_outputs = [task.output]
+                        last_sync_output = task.output
+                continue
+
+            agent_to_use = self._get_agent_to_use(task)
+            if agent_to_use is None:
+                raise ValueError(
+                    f"No agent available for task: {task.description}. "
+                    f"Ensure that either the task has an assigned agent "
+                    f"or a manager agent is provided."
+                )
+
+            tools_for_task = task.tools or agent_to_use.tools or []
+            tools_for_task = self._prepare_tools(
+                agent_to_use,
+                task,
+                tools_for_task,
+            )
+
+            self._log_task_start(task, agent_to_use.role)
+
+            if isinstance(task, ConditionalTask):
+                skipped_task_output = await self._ahandle_conditional_task(
+                    task, task_outputs, pending_tasks, task_index, was_replayed
+                )
+                if skipped_task_output:
+                    task_outputs.append(skipped_task_output)
+                    continue
+
+            if task.async_execution:
+                context = self._get_context(
+                    task, [last_sync_output] if last_sync_output else []
+                )
+                async_task = asyncio.create_task(
+                    task.aexecute_sync(
+                        agent=agent_to_use,
+                        context=context,
+                        tools=tools_for_task,
+                    )
+                )
+                pending_tasks.append((task, async_task, task_index))
+            else:
+                if pending_tasks:
+                    task_outputs = await self._aprocess_async_tasks(
+                        pending_tasks, was_replayed
+                    )
+                    pending_tasks.clear()
+
+                context = self._get_context(task, task_outputs)
+                task_output = await task.aexecute_sync(
+                    agent=agent_to_use,
+                    context=context,
+                    tools=tools_for_task,
+                )
+                task_outputs.append(task_output)
+                self._process_task_result(task, task_output)
+                self._store_execution_log(task, task_output, task_index, was_replayed)
+
+        if pending_tasks:
+            task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
+
+        return self._create_crew_output(task_outputs)
+
+    async def _ahandle_conditional_task(
+        self,
+        task: ConditionalTask,
+        task_outputs: list[TaskOutput],
+        pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
+        task_index: int,
+        was_replayed: bool,
+    ) -> TaskOutput | None:
+        """Handle conditional task evaluation using native async."""
+        if pending_tasks:
+            task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
+            pending_tasks.clear()
+
+        previous_output = task_outputs[-1] if task_outputs else None
+        if previous_output is not None and not task.should_execute(previous_output):
+            self._logger.log(
+                "debug",
+                f"Skipping conditional task: {task.description}",
+                color="yellow",
+            )
+            skipped_task_output = task.get_skipped_task_output()
+
+            if not was_replayed:
+                self._store_execution_log(task, skipped_task_output, task_index)
+            return skipped_task_output
+        return None
+
+    async def _aprocess_async_tasks(
+        self,
+        pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
+        was_replayed: bool = False,
+    ) -> list[TaskOutput]:
+        """Process pending async tasks and return their outputs."""
+        task_outputs: list[TaskOutput] = []
+        for future_task, async_task, task_index in pending_tasks:
+            task_output = await async_task
+            task_outputs.append(task_output)
+            self._process_task_result(future_task, task_output)
+            self._store_execution_log(
+                future_task, task_output, task_index, was_replayed
+            )
+        return task_outputs
+
     def _handle_crew_planning(self) -> None:
         """Handles the Crew planning."""
         self._logger.log("info", "Planning the crew execution")
@@ -1431,6 +1767,16 @@ class Crew(FlowTrackable, BaseModel):
             )
         return None

+    async def aquery_knowledge(
+        self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
+    ) -> list[SearchResult] | None:
+        """Query the crew's knowledge base for relevant information asynchronously."""
+        if self.knowledge:
+            return await self.knowledge.aquery(
+                query, results_limit=results_limit, score_threshold=score_threshold
+            )
+        return None
+
     def fetch_inputs(self) -> set[str]:
         """
         Gathers placeholders (e.g., {something}) referenced in tasks or agents.
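
A hedged usage sketch for the native-async kickoff added above (assumes crewai is installed; `researcher` and `summary_task` stand for an Agent and Task built as usual):

import asyncio

from crewai import Crew, Process

async def main() -> None:
    crew = Crew(
        agents=[researcher],    # placeholder Agent
        tasks=[summary_task],   # placeholder Task
        process=Process.sequential,
    )
    # Native async: awaits task execution, memory, and knowledge directly,
    # instead of wrapping the sync kickoff in a thread (kickoff_async).
    result = await crew.akickoff(inputs={"topic": "async crews"})
    print(result)

asyncio.run(main())
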
@@ -1032,6 +1032,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
         finally:
             detach(flow_token)

+    async def akickoff(
+        self, inputs: dict[str, Any] | None = None
+    ) -> Any | FlowStreamingOutput:
+        """Native async method to start the flow execution. Alias for kickoff_async.
+
+
+        Args:
+            inputs: Optional dictionary containing input values and/or a state ID for restoration.
+
+        Returns:
+            The final output from the flow, which is the result of the last executed method.
+        """
+        return await self.kickoff_async(inputs)
+
     async def _execute_start_method(self, start_method_name: FlowMethodName) -> None:
         """Executes a flow's start method and its triggered listeners.
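
Since `Flow.akickoff` simply awaits `kickoff_async`, existing flows gain the new name with identical semantics. A sketch with a hypothetical `MyFlow` subclass:

import asyncio

from crewai.flow.flow import Flow, start

class MyFlow(Flow):
    @start()
    def begin(self) -> str:
        return "hello from the flow"

async def main() -> None:
    result = await MyFlow().akickoff()
    print(result)  # -> hello from the flow

asyncio.run(main())
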
@@ -32,8 +32,8 @@ class Knowledge(BaseModel):
         sources: list[BaseKnowledgeSource],
         embedder: EmbedderConfig | None = None,
         storage: KnowledgeStorage | None = None,
-        **data,
-    ):
+        **data: object,
+    ) -> None:
         super().__init__(**data)
         if storage:
             self.storage = storage
@@ -75,3 +75,44 @@ class Knowledge(BaseModel):
|
||||
self.storage.reset()
|
||||
else:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
async def aquery(
|
||||
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
|
||||
) -> list[SearchResult]:
|
||||
"""Query across all knowledge sources asynchronously.
|
||||
|
||||
Args:
|
||||
query: List of query strings.
|
||||
results_limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
The top results matching the query.
|
||||
|
||||
Raises:
|
||||
ValueError: If storage is not initialized.
|
||||
"""
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
return await self.storage.asearch(
|
||||
query,
|
||||
limit=results_limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
async def aadd_sources(self) -> None:
|
||||
"""Add all knowledge sources to storage asynchronously."""
|
||||
try:
|
||||
for source in self.sources:
|
||||
source.storage = self.storage
|
||||
await source.aadd()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the knowledge base asynchronously."""
|
||||
if self.storage:
|
||||
await self.storage.areset()
|
||||
else:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
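Together, `aadd_sources`, `aquery`, and `areset` make the whole knowledge lifecycle awaitable. A usage sketch, assuming default storage and embedder; the `collection_name` keyword is illustrative and passed through `**data`:

import asyncio

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

async def main() -> None:
    knowledge = Knowledge(
        collection_name="docs",
        sources=[StringKnowledgeSource(content="Async storage is backed by aiosqlite.")],
    )
    await knowledge.aadd_sources()  # chunk, embed, and persist every source
    print(await knowledge.aquery(["async storage"], results_limit=5))

asyncio.run(main())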
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

from pydantic import Field, field_validator

@@ -25,7 +26,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
    safe_file_paths: list[Path] = Field(default_factory=list)

    @field_validator("file_path", "file_paths", mode="before")
    def validate_file_path(cls, v, info):  # noqa: N805
    @classmethod
    def validate_file_path(
        cls, v: Path | list[Path] | str | list[str] | None, info: Any
    ) -> Path | list[Path] | str | list[str] | None:
        """Validate that at least one of file_path or file_paths is provided."""
        # Single check if both are None, O(1) instead of nested conditions
        if (
@@ -38,7 +42,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
            raise ValueError("Either file_path or file_paths must be provided")
        return v

    def model_post_init(self, _):
    def model_post_init(self, _: Any) -> None:
        """Post-initialization method to load content."""
        self.safe_file_paths = self._process_file_paths()
        self.validate_content()
@@ -48,7 +52,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
    def load_content(self) -> dict[Path, str]:
        """Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""

    def validate_content(self):
    def validate_content(self) -> None:
        """Validate the paths."""
        for path in self.safe_file_paths:
            if not path.exists():
@@ -65,13 +69,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
                    color="red",
                )

    def _save_documents(self):
    def _save_documents(self) -> None:
        """Save the documents to the storage."""
        if self.storage:
            self.storage.save(self.chunks)
        else:
            raise ValueError("No storage found to save documents.")

    async def _asave_documents(self) -> None:
        """Save the documents to the storage asynchronously."""
        if self.storage:
            await self.storage.asave(self.chunks)
        else:
            raise ValueError("No storage found to save documents.")

    def convert_to_path(self, path: Path | str) -> Path:
        """Convert a path to a Path object."""
        return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

@@ -39,12 +39,32 @@ class BaseKnowledgeSource(BaseModel, ABC):
            for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
        ]

    def _save_documents(self):
        """
        Save the documents to the storage.
    def _save_documents(self) -> None:
        """Save the documents to the storage.

        This method should be called after the chunks and embeddings are generated.

        Raises:
            ValueError: If no storage is configured.
        """
        if self.storage:
            self.storage.save(self.chunks)
        else:
            raise ValueError("No storage found to save documents.")

    @abstractmethod
    async def aadd(self) -> None:
        """Process content, chunk it, compute embeddings, and save them asynchronously."""

    async def _asave_documents(self) -> None:
        """Save the documents to the storage asynchronously.

        This method should be called after the chunks and embeddings are generated.

        Raises:
            ValueError: If no storage is configured.
        """
        if self.storage:
            await self.storage.asave(self.chunks)
        else:
            raise ValueError("No storage found to save documents.")
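Because `aadd` is now abstract, custom sources must implement both halves of the contract. A minimal sketch of a subclass, assuming `validate_content` and `add` are the other abstract members, as in the base class above:

from typing import Any

from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource

class MottoSource(BaseKnowledgeSource):
    """Toy source showing the paired sync/async contract (illustrative)."""

    motto: str = "ship small, ship often"

    def validate_content(self) -> Any:
        return self.motto

    def add(self) -> None:
        self.chunks.extend(self._chunk_text(self.motto))
        self._save_documents()

    async def aadd(self) -> None:
        self.chunks.extend(self._chunk_text(self.motto))
        await self._asave_documents()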
@@ -2,27 +2,24 @@ from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse


try:
    from docling.datamodel.base_models import (  # type: ignore[import-not-found]
        InputFormat,
    )
    from docling.document_converter import (  # type: ignore[import-not-found]
        DocumentConverter,
    )
    from docling.exceptions import ConversionError  # type: ignore[import-not-found]
    from docling_core.transforms.chunker.hierarchical_chunker import (  # type: ignore[import-not-found]
        HierarchicalChunker,
    )
    from docling_core.types.doc.document import (  # type: ignore[import-not-found]
        DoclingDocument,
    )
    from docling.datamodel.base_models import InputFormat
    from docling.document_converter import DocumentConverter
    from docling.exceptions import ConversionError
    from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
    from docling_core.types.doc.document import DoclingDocument

    DOCLING_AVAILABLE = True
except ImportError:
    DOCLING_AVAILABLE = False
    # Provide type stubs for when docling is not available
    if TYPE_CHECKING:
        from docling.document_converter import DocumentConverter
        from docling_core.types.doc.document import DoclingDocument

from pydantic import Field

@@ -32,11 +29,13 @@ from crewai.utilities.logger import Logger


class CrewDoclingSource(BaseKnowledgeSource):
    """Default Source class for converting documents to markdown or json
    This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
    """Default Source class for converting documents to markdown or json.

    This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without
    any additional dependencies and follows the docling package as the source of truth.
    """

    def __init__(self, *args, **kwargs):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        if not DOCLING_AVAILABLE:
            raise ImportError(
                "The docling package is required to use CrewDoclingSource. "
@@ -66,7 +65,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
            )
        )

    def model_post_init(self, _) -> None:
    def model_post_init(self, _: Any) -> None:
        if self.file_path:
            self._logger.log(
                "warning",
@@ -99,6 +98,15 @@ class CrewDoclingSource(BaseKnowledgeSource):
        self.chunks.extend(list(new_chunks_iterable))
        self._save_documents()

    async def aadd(self) -> None:
        """Add docling content asynchronously."""
        if self.content is None:
            return
        for doc in self.content:
            new_chunks_iterable = self._chunk_doc(doc)
            self.chunks.extend(list(new_chunks_iterable))
        await self._asave_documents()

    def _convert_source_to_docling_documents(self) -> list[DoclingDocument]:
        conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
        return [result.document for result in conv_results_iter]

@@ -31,6 +31,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

    async def aadd(self) -> None:
        """Add CSV file content asynchronously."""
        content_str = (
            str(self.content) if isinstance(self.content, dict) else self.content
        )
        new_chunks = self._chunk_text(content_str)
        self.chunks.extend(new_chunks)
        await self._asave_documents()

    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [

@@ -1,4 +1,6 @@
from pathlib import Path
from types import ModuleType
from typing import Any

from pydantic import Field, field_validator

@@ -26,7 +28,10 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
    safe_file_paths: list[Path] = Field(default_factory=list)

    @field_validator("file_path", "file_paths", mode="before")
    def validate_file_path(cls, v, info):  # noqa: N805
    @classmethod
    def validate_file_path(
        cls, v: Path | list[Path] | str | list[str] | None, info: Any
    ) -> Path | list[Path] | str | list[str] | None:
        """Validate that at least one of file_path or file_paths is provided."""
        # Single check if both are None, O(1) instead of nested conditions
        if (
@@ -69,7 +74,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):

        return [self.convert_to_path(path) for path in path_list]

    def validate_content(self):
    def validate_content(self) -> None:
        """Validate the paths."""
        for path in self.safe_file_paths:
            if not path.exists():
@@ -86,7 +91,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
                    color="red",
                )

    def model_post_init(self, _) -> None:
    def model_post_init(self, _: Any) -> None:
        if self.file_path:
            self._logger.log(
                "warning",
@@ -128,12 +133,12 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
        """Convert a path to a Path object."""
        return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

    def _import_dependencies(self):
    def _import_dependencies(self) -> ModuleType:
        """Dynamically import dependencies."""
        try:
            import pandas as pd  # type: ignore[import-untyped,import-not-found]
            import pandas as pd  # type: ignore[import-untyped]

            return pd
            return pd  # type: ignore[no-any-return]
        except ImportError as e:
            missing_package = str(e).split()[-1]
            raise ImportError(
@@ -159,6 +164,20 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

    async def aadd(self) -> None:
        """Add Excel file content asynchronously."""
        content_str = ""
        for value in self.content.values():
            if isinstance(value, dict):
                for sheet_value in value.values():
                    content_str += str(sheet_value) + "\n"
            else:
                content_str += str(value) + "\n"

        new_chunks = self._chunk_text(content_str)
        self.chunks.extend(new_chunks)
        await self._asave_documents()

    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [

@@ -44,6 +44,15 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

    async def aadd(self) -> None:
        """Add JSON file content asynchronously."""
        content_str = (
            str(self.content) if isinstance(self.content, dict) else self.content
        )
        new_chunks = self._chunk_text(content_str)
        self.chunks.extend(new_chunks)
        await self._asave_documents()

    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [

@@ -1,4 +1,5 @@
from pathlib import Path
from types import ModuleType

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource

@@ -23,7 +24,7 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
            content[path] = text
        return content

    def _import_pdfplumber(self):
    def _import_pdfplumber(self) -> ModuleType:
        """Dynamically import pdfplumber."""
        try:
            import pdfplumber
@@ -44,6 +45,13 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

    async def aadd(self) -> None:
        """Add PDF file content asynchronously."""
        for text in self.content.values():
            new_chunks = self._chunk_text(text)
            self.chunks.extend(new_chunks)
        await self._asave_documents()

    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [

@@ -1,3 +1,5 @@
from typing import Any

from pydantic import Field

from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
@@ -9,11 +11,11 @@ class StringKnowledgeSource(BaseKnowledgeSource):
    content: str = Field(...)
    collection_name: str | None = Field(default=None)

    def model_post_init(self, _):
    def model_post_init(self, _: Any) -> None:
        """Post-initialization method to validate content."""
        self.validate_content()

    def validate_content(self):
    def validate_content(self) -> None:
        """Validate string content."""
        if not isinstance(self.content, str):
            raise ValueError("StringKnowledgeSource only accepts string content")
@@ -24,6 +26,12 @@ class StringKnowledgeSource(BaseKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

    async def aadd(self) -> None:
        """Add string content asynchronously."""
        new_chunks = self._chunk_text(self.content)
        self.chunks.extend(new_chunks)
        await self._asave_documents()

    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [

@@ -25,6 +25,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

    async def aadd(self) -> None:
        """Add text file content asynchronously."""
        for text in self.content.values():
            new_chunks = self._chunk_text(text)
            self.chunks.extend(new_chunks)
        await self._asave_documents()

    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [

@@ -21,10 +21,28 @@ class BaseKnowledgeStorage(ABC):
    ) -> list[SearchResult]:
        """Search for documents in the knowledge base."""

    @abstractmethod
    async def asearch(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[SearchResult]:
        """Search for documents in the knowledge base asynchronously."""

    @abstractmethod
    def save(self, documents: list[str]) -> None:
        """Save documents to the knowledge base."""

    @abstractmethod
    async def asave(self, documents: list[str]) -> None:
        """Save documents to the knowledge base asynchronously."""

    @abstractmethod
    def reset(self) -> None:
        """Reset the knowledge base."""

    @abstractmethod
    async def areset(self) -> None:
        """Reset the knowledge base asynchronously."""
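Every concrete storage now has to provide the async trio alongside the sync methods. An in-memory stub sketching the shape of the contract — substring matching stands in for real vector search, and results are plain dicts rather than the library's `SearchResult` type:

from typing import Any

from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage

class InMemoryKnowledgeStorage(BaseKnowledgeStorage):
    def __init__(self) -> None:
        self.documents: list[str] = []

    def search(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        hits = [
            {"content": doc}
            for doc in self.documents
            if any(q.lower() in doc.lower() for q in query)
        ]
        return hits[:limit]

    async def asearch(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        return self.search(query, limit, metadata_filter, score_threshold)

    def save(self, documents: list[str]) -> None:
        self.documents.extend(documents)

    async def asave(self, documents: list[str]) -> None:
        self.save(documents)

    def reset(self) -> None:
        self.documents.clear()

    async def areset(self) -> None:
        self.reset()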
@@ -25,8 +25,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
    def __init__(
        self,
        embedder: ProviderSpec
        | BaseEmbeddingsProvider
        | type[BaseEmbeddingsProvider]
        | BaseEmbeddingsProvider[Any]
        | type[BaseEmbeddingsProvider[Any]]
        | None = None,
        collection_name: str | None = None,
    ) -> None:
@@ -127,3 +127,96 @@ class KnowledgeStorage(BaseKnowledgeStorage):
                ) from e
            Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
            raise

    async def asearch(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[SearchResult]:
        """Search for documents in the knowledge base asynchronously.

        Args:
            query: List of query strings.
            limit: Maximum number of results to return.
            metadata_filter: Optional metadata filter for the search.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of search results.
        """
        try:
            if not query:
                raise ValueError("Query cannot be empty")

            client = self._get_client()
            collection_name = (
                f"knowledge_{self.collection_name}"
                if self.collection_name
                else "knowledge"
            )
            query_text = " ".join(query) if len(query) > 1 else query[0]

            return await client.asearch(
                collection_name=collection_name,
                query=query_text,
                limit=limit,
                metadata_filter=metadata_filter,
                score_threshold=score_threshold,
            )
        except Exception as e:
            logging.error(
                f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
            )
            return []

    async def asave(self, documents: list[str]) -> None:
        """Save documents to the knowledge base asynchronously.

        Args:
            documents: List of document strings to save.
        """
        try:
            client = self._get_client()
            collection_name = (
                f"knowledge_{self.collection_name}"
                if self.collection_name
                else "knowledge"
            )
            await client.aget_or_create_collection(collection_name=collection_name)

            rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]

            await client.aadd_documents(
                collection_name=collection_name, documents=rag_documents
            )
        except Exception as e:
            if "dimension mismatch" in str(e).lower():
                Logger(verbose=True).log(
                    "error",
                    "Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
                    "red",
                )
                raise ValueError(
                    "Embedding dimension mismatch. Make sure you're using the same embedding model "
                    "across all operations with this collection. "
                    "Try resetting the collection using `crewai reset-memories -a`"
                ) from e
            Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
            raise

    async def areset(self) -> None:
        """Reset the knowledge base asynchronously."""
        try:
            client = self._get_client()
            collection_name = (
                f"knowledge_{self.collection_name}"
                if self.collection_name
                else "knowledge"
            )
            await client.adelete_collection(collection_name=collection_name)
        except Exception as e:
            logging.error(
                f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
            )
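The storage layer can also be driven directly, which is useful in tests. A sketch, assuming the default embedder configuration:

import asyncio

from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage

async def main() -> None:
    storage = KnowledgeStorage(collection_name="handbook")
    await storage.asave(["Deploys happen after standup, never before."])
    print(await storage.asearch(["deploys"], limit=3, score_threshold=0.35))

asyncio.run(main())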
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

from crewai.memory import (
@@ -16,6 +17,8 @@ if TYPE_CHECKING:


class ContextualMemory:
    """Aggregates and retrieves context from multiple memory sources."""

    def __init__(
        self,
        stm: ShortTermMemory,
@@ -46,9 +49,14 @@ class ContextualMemory:
            self.exm.task = self.task

    def build_context_for_task(self, task: Task, context: str) -> str:
        """
        Automatically builds a minimal, highly relevant set of contextual information
        for a given task.
        """Build contextual information for a task synchronously.

        Args:
            task: The task to build context for.
            context: Additional context string.

        Returns:
            Formatted context string from all memory sources.
        """
        query = f"{task.description} {context}".strip()

@@ -63,6 +71,31 @@ class ContextualMemory:
        ]
        return "\n".join(filter(None, context_parts))

    async def abuild_context_for_task(self, task: Task, context: str) -> str:
        """Build contextual information for a task asynchronously.

        Args:
            task: The task to build context for.
            context: Additional context string.

        Returns:
            Formatted context string from all memory sources.
        """
        query = f"{task.description} {context}".strip()

        if query == "":
            return ""

        # Fetch all contexts concurrently
        results = await asyncio.gather(
            self._afetch_ltm_context(task.description),
            self._afetch_stm_context(query),
            self._afetch_entity_context(query),
            self._afetch_external_context(query),
        )

        return "\n".join(filter(None, results))

    def _fetch_stm_context(self, query: str) -> str:
        """
        Fetches recent relevant insights from STM related to the task's description and expected_output,
@@ -135,3 +168,87 @@ class ContextualMemory:
            f"- {result['content']}" for result in external_memories
        )
        return f"External memories:\n{formatted_memories}"

    async def _afetch_stm_context(self, query: str) -> str:
        """Fetch recent relevant insights from STM asynchronously.

        Args:
            query: The search query.

        Returns:
            Formatted insights as bullet points, or empty string if none found.
        """
        if self.stm is None:
            return ""

        stm_results = await self.stm.asearch(query)
        formatted_results = "\n".join(
            [f"- {result['content']}" for result in stm_results]
        )
        return f"Recent Insights:\n{formatted_results}" if stm_results else ""

    async def _afetch_ltm_context(self, task: str) -> str | None:
        """Fetch historical data from LTM asynchronously.

        Args:
            task: The task description to search for.

        Returns:
            Formatted historical data as bullet points, or None if none found.
        """
        if self.ltm is None:
            return ""

        ltm_results = await self.ltm.asearch(task, latest_n=2)
        if not ltm_results:
            return None

        formatted_results = [
            suggestion
            for result in ltm_results
            for suggestion in result["metadata"]["suggestions"]
        ]
        formatted_results = list(dict.fromkeys(formatted_results))
        formatted_results = "\n".join([f"- {result}" for result in formatted_results])  # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")

        return f"Historical Data:\n{formatted_results}" if ltm_results else ""

    async def _afetch_entity_context(self, query: str) -> str:
        """Fetch relevant entity information asynchronously.

        Args:
            query: The search query.

        Returns:
            Formatted entity information as bullet points, or empty string if none found.
        """
        if self.em is None:
            return ""

        em_results = await self.em.asearch(query)
        formatted_results = "\n".join(
            [f"- {result['content']}" for result in em_results]
        )
        return f"Entities:\n{formatted_results}" if em_results else ""

    async def _afetch_external_context(self, query: str) -> str:
        """Fetch relevant information from External Memory asynchronously.

        Args:
            query: The search query.

        Returns:
            Formatted information as bullet points, or empty string if none found.
        """
        if self.exm is None:
            return ""

        external_memories = await self.exm.asearch(query)

        if not external_memories:
            return ""

        formatted_memories = "\n".join(
            f"- {result['content']}" for result in external_memories
        )
        return f"External memories:\n{formatted_memories}"
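The asynchronous build path fans the four lookups out with `asyncio.gather` and drops empty sections, so total latency tracks the slowest store instead of the sum. The pattern in isolation, with a stub standing in for the storage round-trips:

import asyncio

async def fetch(label: str) -> str:
    await asyncio.sleep(0)  # stands in for a storage round-trip
    return f"{label}: ..." if label != "entities" else ""

async def main() -> None:
    parts = await asyncio.gather(
        fetch("historical"), fetch("recent"), fetch("entities"), fetch("external")
    )
    print("\n".join(filter(None, parts)))

asyncio.run(main())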
@@ -26,7 +26,13 @@ class EntityMemory(Memory):

    _memory_provider: str | None = PrivateAttr()

    def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
    def __init__(
        self,
        crew: Any = None,
        embedder_config: Any = None,
        storage: Any = None,
        path: str | None = None,
    ) -> None:
        memory_provider = None
        if embedder_config and isinstance(embedder_config, dict):
            memory_provider = embedder_config.get("provider")
@@ -43,7 +49,7 @@ class EntityMemory(Memory):
                if embedder_config and isinstance(embedder_config, dict)
                else None
            )
            storage = Mem0Storage(type="short_term", crew=crew, config=config)
            storage = Mem0Storage(type="short_term", crew=crew, config=config)  # type: ignore[no-untyped-call]
        else:
            storage = (
                storage
@@ -170,7 +176,17 @@ class EntityMemory(Memory):
        query: str,
        limit: int = 5,
        score_threshold: float = 0.6,
    ):
    ) -> list[Any]:
        """Search entity memory for relevant entries.

        Args:
            query: The search query.
            limit: Maximum number of results to return.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of matching memory entries.
        """
        crewai_event_bus.emit(
            self,
            event=MemoryQueryStartedEvent(
@@ -217,6 +233,168 @@ class EntityMemory(Memory):
            )
            raise

    async def asave(
        self,
        value: EntityMemoryItem | list[EntityMemoryItem],
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Save entity items asynchronously.

        Args:
            value: Single EntityMemoryItem or list of EntityMemoryItems to save.
            metadata: Optional metadata dict (not used, for signature compatibility).
        """
        if not value:
            return

        items = value if isinstance(value, list) else [value]
        is_batch = len(items) > 1

        metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
        crewai_event_bus.emit(
            self,
            event=MemorySaveStartedEvent(
                metadata=metadata,
                source_type="entity_memory",
                from_agent=self.agent,
                from_task=self.task,
            ),
        )

        start_time = time.time()
        saved_count = 0
        errors: list[str | None] = []

        async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
            """Save a single item asynchronously."""
            try:
                if self._memory_provider == "mem0":
                    data = f"""
                    Remember details about the following entity:
                    Name: {item.name}
                    Type: {item.type}
                    Entity Description: {item.description}
                    """
                else:
                    data = f"{item.name}({item.type}): {item.description}"

                await super(EntityMemory, self).asave(data, item.metadata)
                return True, None
            except Exception as e:
                return False, f"{item.name}: {e!s}"

        try:
            for item in items:
                success, error = await save_single_item(item)
                if success:
                    saved_count += 1
                else:
                    errors.append(error)

            if is_batch:
                emit_value = f"Saved {saved_count} entities"
                metadata = {"entity_count": saved_count, "errors": errors}
            else:
                emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
                metadata = items[0].metadata

            crewai_event_bus.emit(
                self,
                event=MemorySaveCompletedEvent(
                    value=emit_value,
                    metadata=metadata,
                    save_time_ms=(time.time() - start_time) * 1000,
                    source_type="entity_memory",
                    from_agent=self.agent,
                    from_task=self.task,
                ),
            )

            if errors:
                raise Exception(
                    f"Partial save: {len(errors)} failed out of {len(items)}"
                )

        except Exception as e:
            fail_metadata = (
                {"entity_count": len(items), "saved": saved_count}
                if is_batch
                else items[0].metadata
            )
            crewai_event_bus.emit(
                self,
                event=MemorySaveFailedEvent(
                    metadata=fail_metadata,
                    error=str(e),
                    source_type="entity_memory",
                    from_agent=self.agent,
                    from_task=self.task,
                ),
            )
            raise

    async def asearch(
        self,
        query: str,
        limit: int = 5,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        """Search entity memory asynchronously.

        Args:
            query: The search query.
            limit: Maximum number of results to return.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of matching memory entries.
        """
        crewai_event_bus.emit(
            self,
            event=MemoryQueryStartedEvent(
                query=query,
                limit=limit,
                score_threshold=score_threshold,
                source_type="entity_memory",
                from_agent=self.agent,
                from_task=self.task,
            ),
        )

        start_time = time.time()
        try:
            results = await super().asearch(
                query=query, limit=limit, score_threshold=score_threshold
            )

            crewai_event_bus.emit(
                self,
                event=MemoryQueryCompletedEvent(
                    query=query,
                    results=results,
                    limit=limit,
                    score_threshold=score_threshold,
                    query_time_ms=(time.time() - start_time) * 1000,
                    source_type="entity_memory",
                    from_agent=self.agent,
                    from_task=self.task,
                ),
            )

            return results
        except Exception as e:
            crewai_event_bus.emit(
                self,
                event=MemoryQueryFailedEvent(
                    query=query,
                    limit=limit,
                    score_threshold=score_threshold,
                    error=str(e),
                    source_type="entity_memory",
                ),
            )
            raise

    def reset(self) -> None:
        try:
            self.storage.reset()
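`asave` accepts either a single item or a batch, emitting one completed event either way and surfacing partial failures. A usage sketch; the `EntityMemoryItem` field names follow the `name`/`type`/`description` usage visible in `save_single_item` above, while the import paths and the `relationships` field are assumptions:

import asyncio

from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.entity.entity_memory_item import EntityMemoryItem

async def main() -> None:
    memory = EntityMemory()
    item = EntityMemoryItem(
        name="CrewAI",
        type="framework",
        description="Multi-agent orchestration library",
        relationships="",  # assumed field
    )
    await memory.asave(item)
    print(await memory.asearch("framework", limit=3))

asyncio.run(main())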
@@ -30,7 +30,7 @@ class ExternalMemory(Memory):
    def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
        from crewai.memory.storage.mem0_storage import Mem0Storage

        return Mem0Storage(type="external", crew=crew, config=config)
        return Mem0Storage(type="external", crew=crew, config=config)  # type: ignore[no-untyped-call]

    @staticmethod
    def external_supported_storages() -> dict[str, Any]:
@@ -53,7 +53,10 @@ class ExternalMemory(Memory):
        if provider not in supported_storages:
            raise ValueError(f"Provider {provider} not supported")

        return supported_storages[provider](crew, embedder_config.get("config", {}))
        storage: Storage = supported_storages[provider](
            crew, embedder_config.get("config", {})
        )
        return storage

    def save(
        self,
@@ -111,7 +114,17 @@ class ExternalMemory(Memory):
        query: str,
        limit: int = 5,
        score_threshold: float = 0.6,
    ):
    ) -> list[Any]:
        """Search external memory for relevant entries.

        Args:
            query: The search query.
            limit: Maximum number of results to return.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of matching memory entries.
        """
        crewai_event_bus.emit(
            self,
            event=MemoryQueryStartedEvent(
@@ -158,6 +171,124 @@ class ExternalMemory(Memory):
            )
            raise

    async def asave(
        self,
        value: Any,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Save a value to external memory asynchronously.

        Args:
            value: The value to save.
            metadata: Optional metadata to associate with the value.
        """
        crewai_event_bus.emit(
            self,
            event=MemorySaveStartedEvent(
                value=value,
                metadata=metadata,
                source_type="external_memory",
                from_agent=self.agent,
                from_task=self.task,
            ),
        )

        start_time = time.time()
        try:
            item = ExternalMemoryItem(
                value=value,
                metadata=metadata,
                agent=self.agent.role if self.agent else None,
            )
            await super().asave(value=item.value, metadata=item.metadata)

            crewai_event_bus.emit(
                self,
                event=MemorySaveCompletedEvent(
                    value=value,
                    metadata=metadata,
                    save_time_ms=(time.time() - start_time) * 1000,
                    source_type="external_memory",
                    from_agent=self.agent,
                    from_task=self.task,
                ),
            )
        except Exception as e:
            crewai_event_bus.emit(
                self,
                event=MemorySaveFailedEvent(
                    value=value,
                    metadata=metadata,
                    error=str(e),
                    source_type="external_memory",
                    from_agent=self.agent,
                    from_task=self.task,
                ),
            )
            raise

    async def asearch(
        self,
        query: str,
        limit: int = 5,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        """Search external memory asynchronously.

        Args:
            query: The search query.
            limit: Maximum number of results to return.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of matching memory entries.
        """
        crewai_event_bus.emit(
            self,
            event=MemoryQueryStartedEvent(
                query=query,
                limit=limit,
                score_threshold=score_threshold,
                source_type="external_memory",
                from_agent=self.agent,
                from_task=self.task,
            ),
        )

        start_time = time.time()
        try:
            results = await super().asearch(
                query=query, limit=limit, score_threshold=score_threshold
            )

            crewai_event_bus.emit(
                self,
                event=MemoryQueryCompletedEvent(
                    query=query,
                    results=results,
                    limit=limit,
                    score_threshold=score_threshold,
                    query_time_ms=(time.time() - start_time) * 1000,
                    source_type="external_memory",
                    from_agent=self.agent,
                    from_task=self.task,
                ),
            )

            return results
        except Exception as e:
            crewai_event_bus.emit(
                self,
                event=MemoryQueryFailedEvent(
                    query=query,
                    limit=limit,
                    score_threshold=score_threshold,
                    error=str(e),
                    source_type="external_memory",
                ),
            )
            raise

    def reset(self) -> None:
        self.storage.reset()


@@ -24,7 +24,11 @@ class LongTermMemory(Memory):
    LongTermMemoryItem instances.
    """

    def __init__(self, storage=None, path=None):
    def __init__(
        self,
        storage: LTMSQLiteStorage | None = None,
        path: str | None = None,
    ) -> None:
        if not storage:
            storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
        super().__init__(storage=storage)
@@ -48,7 +52,7 @@ class LongTermMemory(Memory):
            metadata.update(
                {"agent": item.agent, "expected_output": item.expected_output}
            )
            self.storage.save(  # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
            self.storage.save(
                task_description=item.task,
                score=metadata["quality"],
                metadata=metadata,
@@ -80,11 +84,20 @@ class LongTermMemory(Memory):
            )
            raise

    def search(  # type: ignore # signature of "search" incompatible with supertype "Memory"
    def search(  # type: ignore[override]
        self,
        task: str,
        latest_n: int = 3,
    ) -> list[dict[str, Any]]:  # type: ignore # signature of "search" incompatible with supertype "Memory"
    ) -> list[dict[str, Any]]:
        """Search long-term memory for relevant entries.

        Args:
            task: The task description to search for.
            latest_n: Maximum number of results to return.

        Returns:
            List of matching memory entries.
        """
        crewai_event_bus.emit(
            self,
            event=MemoryQueryStartedEvent(
@@ -98,7 +111,7 @@ class LongTermMemory(Memory):

        start_time = time.time()
        try:
            results = self.storage.load(task, latest_n)  # type: ignore # BUG?: "Storage" has no attribute "load"
            results = self.storage.load(task, latest_n)

            crewai_event_bus.emit(
                self,
@@ -113,7 +126,118 @@ class LongTermMemory(Memory):
                ),
            )

            return results
            return results or []
        except Exception as e:
            crewai_event_bus.emit(
                self,
                event=MemoryQueryFailedEvent(
                    query=task,
                    limit=latest_n,
                    error=str(e),
                    source_type="long_term_memory",
                ),
            )
            raise

    async def asave(self, item: LongTermMemoryItem) -> None:  # type: ignore[override]
        """Save an item to long-term memory asynchronously.

        Args:
            item: The LongTermMemoryItem to save.
        """
        crewai_event_bus.emit(
            self,
            event=MemorySaveStartedEvent(
                value=item.task,
                metadata=item.metadata,
                agent_role=item.agent,
                source_type="long_term_memory",
                from_agent=self.agent,
                from_task=self.task,
            ),
        )

        start_time = time.time()
        try:
            metadata = item.metadata
            metadata.update(
                {"agent": item.agent, "expected_output": item.expected_output}
            )
            await self.storage.asave(
                task_description=item.task,
                score=metadata["quality"],
                metadata=metadata,
                datetime=item.datetime,
            )

            crewai_event_bus.emit(
                self,
                event=MemorySaveCompletedEvent(
                    value=item.task,
                    metadata=item.metadata,
                    agent_role=item.agent,
                    save_time_ms=(time.time() - start_time) * 1000,
                    source_type="long_term_memory",
                    from_agent=self.agent,
                    from_task=self.task,
                ),
            )
        except Exception as e:
            crewai_event_bus.emit(
                self,
                event=MemorySaveFailedEvent(
                    value=item.task,
                    metadata=item.metadata,
                    agent_role=item.agent,
                    error=str(e),
                    source_type="long_term_memory",
                ),
            )
            raise

    async def asearch(  # type: ignore[override]
        self,
        task: str,
        latest_n: int = 3,
    ) -> list[dict[str, Any]]:
        """Search long-term memory asynchronously.

        Args:
            task: The task description to search for.
            latest_n: Maximum number of results to return.

        Returns:
            List of matching memory entries.
        """
        crewai_event_bus.emit(
            self,
            event=MemoryQueryStartedEvent(
                query=task,
                limit=latest_n,
                source_type="long_term_memory",
                from_agent=self.agent,
                from_task=self.task,
            ),
        )

        start_time = time.time()
        try:
            results = await self.storage.aload(task, latest_n)

            crewai_event_bus.emit(
                self,
                event=MemoryQueryCompletedEvent(
                    query=task,
                    results=results,
                    limit=latest_n,
                    query_time_ms=(time.time() - start_time) * 1000,
                    source_type="long_term_memory",
                    from_agent=self.agent,
                    from_task=self.task,
                ),
            )

            return results or []
        except Exception as e:
            crewai_event_bus.emit(
                self,
@@ -127,4 +251,5 @@ class LongTermMemory(Memory):
        raise

    def reset(self) -> None:
        """Reset long-term memory."""
        self.storage.reset()
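`asave` requires `metadata["quality"]` to be present (it becomes the stored score), so items need that key populated. A sketch; beyond the `task`/`agent`/`expected_output`/`datetime`/`metadata` attributes used above, the `LongTermMemoryItem` constructor fields and import paths are assumptions:

import asyncio

from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem

async def main() -> None:
    ltm = LongTermMemory(path="./ltm.db")
    await ltm.asave(
        LongTermMemoryItem(
            agent="researcher",
            task="summarize release notes",
            expected_output="a short summary",
            datetime="2025-01-01T00:00:00",
            quality=8,  # assumed field
            metadata={"quality": 8, "suggestions": ["cite sources"]},
        )
    )
    print(await ltm.asearch("summarize release notes", latest_n=2))

asyncio.run(main())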
@@ -13,9 +13,7 @@ if TYPE_CHECKING:


class Memory(BaseModel):
    """
    Base class for memory, now supporting agent tags and generic metadata.
    """
    """Base class for memory, supporting agent tags and generic metadata."""

    embedder_config: EmbedderConfig | dict[str, Any] | None = None
    crew: Any | None = None
@@ -52,20 +50,72 @@ class Memory(BaseModel):
        value: Any,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        metadata = metadata or {}
        """Save a value to memory.

        Args:
            value: The value to save.
            metadata: Optional metadata to associate with the value.
        """
        metadata = metadata or {}
        self.storage.save(value, metadata)

    async def asave(
        self,
        value: Any,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Save a value to memory asynchronously.

        Args:
            value: The value to save.
            metadata: Optional metadata to associate with the value.
        """
        metadata = metadata or {}
        await self.storage.asave(value, metadata)

    def search(
        self,
        query: str,
        limit: int = 5,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        return self.storage.search(
        """Search memory for relevant entries.

        Args:
            query: The search query.
            limit: Maximum number of results to return.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of matching memory entries.
        """
        results: list[Any] = self.storage.search(
            query=query, limit=limit, score_threshold=score_threshold
        )
        return results

    async def asearch(
        self,
        query: str,
        limit: int = 5,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        """Search memory for relevant entries asynchronously.

        Args:
            query: The search query.
            limit: Maximum number of results to return.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of matching memory entries.
        """
        results: list[Any] = await self.storage.asearch(
            query=query, limit=limit, score_threshold=score_threshold
        )
        return results

    def set_crew(self, crew: Any) -> Memory:
        """Set the crew for this memory instance."""
        self.crew = crew
        return self

@@ -30,7 +30,13 @@ class ShortTermMemory(Memory):

    _memory_provider: str | None = PrivateAttr()

    def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
    def __init__(
        self,
        crew: Any = None,
        embedder_config: Any = None,
        storage: Any = None,
        path: str | None = None,
    ) -> None:
        memory_provider = None
        if embedder_config and isinstance(embedder_config, dict):
            memory_provider = embedder_config.get("provider")
@@ -47,7 +53,7 @@ class ShortTermMemory(Memory):
                if embedder_config and isinstance(embedder_config, dict)
                else None
            )
            storage = Mem0Storage(type="short_term", crew=crew, config=config)
            storage = Mem0Storage(type="short_term", crew=crew, config=config)  # type: ignore[no-untyped-call]
        else:
            storage = (
                storage
@@ -123,7 +129,17 @@ class ShortTermMemory(Memory):
        query: str,
        limit: int = 5,
        score_threshold: float = 0.6,
    ):
    ) -> list[Any]:
        """Search short-term memory for relevant entries.

        Args:
            query: The search query.
            limit: Maximum number of results to return.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of matching memory entries.
        """
        crewai_event_bus.emit(
            self,
            event=MemoryQueryStartedEvent(
@@ -140,7 +156,7 @@ class ShortTermMemory(Memory):
        try:
            results = self.storage.search(
                query=query, limit=limit, score_threshold=score_threshold
            )  # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
            )

            crewai_event_bus.emit(
                self,
@@ -156,7 +172,130 @@ class ShortTermMemory(Memory):
                ),
            )

            return results
            return list(results)
        except Exception as e:
            crewai_event_bus.emit(
                self,
                event=MemoryQueryFailedEvent(
                    query=query,
                    limit=limit,
                    score_threshold=score_threshold,
                    error=str(e),
                    source_type="short_term_memory",
                ),
            )
            raise

    async def asave(
        self,
        value: Any,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Save a value to short-term memory asynchronously.

        Args:
            value: The value to save.
            metadata: Optional metadata to associate with the value.
        """
        crewai_event_bus.emit(
            self,
            event=MemorySaveStartedEvent(
                value=value,
                metadata=metadata,
                source_type="short_term_memory",
                from_agent=self.agent,
                from_task=self.task,
            ),
        )

        start_time = time.time()
        try:
            item = ShortTermMemoryItem(
                data=value,
                metadata=metadata,
                agent=self.agent.role if self.agent else None,
            )
            if self._memory_provider == "mem0":
                item.data = (
                    f"Remember the following insights from Agent run: {item.data}"
                )

            await super().asave(value=item.data, metadata=item.metadata)

            crewai_event_bus.emit(
                self,
                event=MemorySaveCompletedEvent(
                    value=value,
                    metadata=metadata,
                    save_time_ms=(time.time() - start_time) * 1000,
                    source_type="short_term_memory",
                    from_agent=self.agent,
                    from_task=self.task,
                ),
            )
        except Exception as e:
            crewai_event_bus.emit(
                self,
                event=MemorySaveFailedEvent(
                    value=value,
                    metadata=metadata,
                    error=str(e),
                    source_type="short_term_memory",
                    from_agent=self.agent,
                    from_task=self.task,
                ),
            )
            raise

    async def asearch(
        self,
        query: str,
        limit: int = 5,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        """Search short-term memory asynchronously.

        Args:
            query: The search query.
            limit: Maximum number of results to return.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of matching memory entries.
        """
        crewai_event_bus.emit(
            self,
            event=MemoryQueryStartedEvent(
                query=query,
                limit=limit,
                score_threshold=score_threshold,
                source_type="short_term_memory",
                from_agent=self.agent,
                from_task=self.task,
            ),
        )

        start_time = time.time()
        try:
            results = await self.storage.asearch(
                query=query, limit=limit, score_threshold=score_threshold
            )

            crewai_event_bus.emit(
                self,
                event=MemoryQueryCompletedEvent(
                    query=query,
                    results=results,
                    limit=limit,
                    score_threshold=score_threshold,
                    query_time_ms=(time.time() - start_time) * 1000,
                    source_type="short_term_memory",
                    from_agent=self.agent,
                    from_task=self.task,
                ),
            )

            return list(results)
        except Exception as e:
            crewai_event_bus.emit(
                self,
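Short-term memory follows the same evented pattern. A minimal sketch, assuming the default RAG storage backend:

import asyncio

from crewai.memory.short_term.short_term_memory import ShortTermMemory

async def main() -> None:
    stm = ShortTermMemory()
    await stm.asave("User prefers concise answers", metadata={"topic": "style"})
    print(await stm.asearch("concise", limit=3, score_threshold=0.35))

asyncio.run(main())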
@@ -3,29 +3,30 @@ from pathlib import Path
import sqlite3
from typing import Any

import aiosqlite

from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path


class LTMSQLiteStorage:
    """
    An updated SQLite storage class for LTM data storage.
    """
    """SQLite storage class for long-term memory data."""

    def __init__(self, db_path: str | None = None) -> None:
        """Initialize the SQLite storage.

        Args:
            db_path: Optional path to the database file.
        """
        if db_path is None:
            # Get the parent directory of the default db path and create our db file there
            db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
        self.db_path = db_path
        self._printer: Printer = Printer()
        # Ensure parent directory exists
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
        self._initialize_db()

    def _initialize_db(self):
        """
        Initializes the SQLite database and creates LTM table
        """
    def _initialize_db(self) -> None:
        """Initialize the SQLite database and create LTM table."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
@@ -106,9 +107,7 @@ class LTMSQLiteStorage:
            )
        return None

    def reset(
        self,
    ) -> None:
    def reset(self) -> None:
        """Resets the LTM table with error handling."""
        try:
            with sqlite3.connect(self.db_path) as conn:
@@ -121,4 +120,87 @@ class LTMSQLiteStorage:
                content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
                color="red",
            )
        return

    async def asave(
        self,
        task_description: str,
        metadata: dict[str, Any],
        datetime: str,
        score: int | float,
    ) -> None:
        """Save data to the LTM table asynchronously.

        Args:
            task_description: Description of the task.
            metadata: Metadata associated with the memory.
            datetime: Timestamp of the memory.
            score: Quality score of the memory.
        """
        try:
            async with aiosqlite.connect(self.db_path) as conn:
                await conn.execute(
                    """
                    INSERT INTO long_term_memories (task_description, metadata, datetime, score)
                    VALUES (?, ?, ?, ?)
                    """,
                    (task_description, json.dumps(metadata), datetime, score),
                )
                await conn.commit()
        except aiosqlite.Error as e:
            self._printer.print(
                content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
                color="red",
            )

    async def aload(
        self, task_description: str, latest_n: int
    ) -> list[dict[str, Any]] | None:
        """Query the LTM table by task description asynchronously.

        Args:
            task_description: Description of the task to search for.
            latest_n: Maximum number of results to return.

        Returns:
            List of matching memory entries or None if error occurs.
        """
        try:
            async with aiosqlite.connect(self.db_path) as conn:
                cursor = await conn.execute(
                    f"""
                    SELECT metadata, datetime, score
                    FROM long_term_memories
                    WHERE task_description = ?
                    ORDER BY datetime DESC, score ASC
                    LIMIT {latest_n}
                    """,  # nosec # noqa: S608
                    (task_description,),
                )
                rows = await cursor.fetchall()
                if rows:
                    return [
                        {
                            "metadata": json.loads(row[0]),
                            "datetime": row[1],
                            "score": row[2],
                        }
                        for row in rows
                    ]
        except aiosqlite.Error as e:
            self._printer.print(
                content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
                color="red",
            )
        return None

    async def areset(self) -> None:
        """Reset the LTM table asynchronously."""
        try:
            async with aiosqlite.connect(self.db_path) as conn:
                await conn.execute("DELETE FROM long_term_memories")
                await conn.commit()
        except aiosqlite.Error as e:
            self._printer.print(
                content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
                color="red",
            )
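The async methods all follow the same aiosqlite shape: one connection per call, opened in an async context manager, with an explicit commit for writes. The pattern in isolation:

import asyncio

import aiosqlite

async def main() -> None:
    async with aiosqlite.connect(":memory:") as conn:
        await conn.execute("CREATE TABLE notes (body TEXT)")
        await conn.execute("INSERT INTO notes VALUES (?)", ("hello",))
        await conn.commit()
        cursor = await conn.execute("SELECT body FROM notes")
        print(await cursor.fetchall())

asyncio.run(main())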
@@ -129,6 +129,12 @@ class RAGStorage(BaseRAGStorage):
        return f"{base_path}/{file_name}"

    def save(self, value: Any, metadata: dict[str, Any]) -> None:
        """Save a value to storage.

        Args:
            value: The value to save.
            metadata: Metadata to associate with the value.
        """
        try:
            client = self._get_client()
            collection_name = (
@@ -167,6 +173,51 @@ class RAGStorage(BaseRAGStorage):
                f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
            )

    async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
        """Save a value to storage asynchronously.

        Args:
            value: The value to save.
            metadata: Metadata to associate with the value.
        """
        try:
            client = self._get_client()
            collection_name = (
                f"memory_{self.type}_{self.agents}"
                if self.agents
                else f"memory_{self.type}"
            )
            await client.aget_or_create_collection(collection_name=collection_name)

            document: BaseRecord = {"content": value}
            if metadata:
                document["metadata"] = metadata

            batch_size = None
            if (
                self.embedder_config
                and isinstance(self.embedder_config, dict)
                and "config" in self.embedder_config
            ):
                nested_config = self.embedder_config["config"]
                if isinstance(nested_config, dict):
                    batch_size = nested_config.get("batch_size")

            if batch_size is not None:
                await client.aadd_documents(
                    collection_name=collection_name,
                    documents=[document],
                    batch_size=cast(int, batch_size),
                )
            else:
                await client.aadd_documents(
                    collection_name=collection_name, documents=[document]
                )
        except Exception as e:
            logging.error(
                f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}"
            )

    def search(
        self,
        query: str,
@@ -174,6 +225,17 @@ class RAGStorage(BaseRAGStorage):
        filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        """Search for matching entries in storage.

        Args:
            query: The search query.
            limit: Maximum number of results to return.
            filter: Optional metadata filter.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of matching entries.
        """
        try:
            client = self._get_client()
            collection_name = (
@@ -194,6 +256,44 @@ class RAGStorage(BaseRAGStorage):
            )
            return []

    async def asearch(
        self,
        query: str,
        limit: int = 5,
        filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        """Search for matching entries in storage asynchronously.

        Args:
            query: The search query.
            limit: Maximum number of results to return.
            filter: Optional metadata filter.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of matching entries.
        """
        try:
            client = self._get_client()
            collection_name = (
                f"memory_{self.type}_{self.agents}"
                if self.agents
                else f"memory_{self.type}"
            )
            return await client.asearch(
                collection_name=collection_name,
                query=query,
                limit=limit,
                metadata_filter=filter,
                score_threshold=score_threshold,
            )
        except Exception as e:
            logging.error(
                f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}"
            )
            return []

    def reset(self) -> None:
        try:
            client = self._get_client()
@@ -497,6 +497,107 @@ class Task(BaseModel):
|
||||
result = self._execute_core(agent, context, tools)
|
||||
future.set_result(result)
|
||||
|
||||
async def aexecute_sync(
|
||||
self,
|
||||
agent: BaseAgent | None = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> TaskOutput:
|
||||
"""Execute the task asynchronously using native async/await."""
|
||||
return await self._aexecute_core(agent, context, tools)
|
||||
|
||||
async def _aexecute_core(
|
||||
self,
|
||||
agent: BaseAgent | None,
|
||||
context: str | None,
|
||||
tools: list[Any] | None,
|
||||
) -> TaskOutput:
|
||||
"""Run the core execution logic of the task asynchronously."""
|
||||
try:
|
||||
agent = agent or self.agent
|
||||
self.agent = agent
|
||||
if not agent:
|
||||
raise Exception(
|
||||
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
|
||||
)
|
||||
|
||||
self.start_time = datetime.datetime.now()
|
||||
|
||||
self.prompt_context = context
|
||||
tools = tools or self.tools or []
|
||||
|
||||
self.processed_by_agents.add(agent.role)
|
||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
|
||||
result = await agent.aexecute_task(
|
||||
task=self,
|
||||
context=context,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
if not self._guardrails and not self._guardrail:
|
||||
pydantic_output, json_output = self._export_output(result)
|
||||
else:
|
||||
pydantic_output, json_output = None, None
|
||||
|
||||
task_output = TaskOutput(
|
||||
name=self.name or self.description,
|
||||
description=self.description,
|
||||
expected_output=self.expected_output,
|
||||
raw=result,
|
||||
pydantic=pydantic_output,
|
||||
json_dict=json_output,
|
||||
agent=agent.role,
|
||||
output_format=self._get_output_format(),
|
||||
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if self._guardrails:
|
||||
for idx, guardrail in enumerate(self._guardrails):
|
||||
task_output = await self._ainvoke_guardrail_function(
|
||||
task_output=task_output,
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
guardrail=guardrail,
|
||||
guardrail_index=idx,
|
||||
)
|
||||
|
||||
if self._guardrail:
|
||||
task_output = await self._ainvoke_guardrail_function(
|
||||
task_output=task_output,
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
guardrail=self._guardrail,
|
||||
)
|
||||
|
||||
self.output = task_output
|
||||
self.end_time = datetime.datetime.now()
|
||||
|
||||
if self.callback:
|
||||
self.callback(self.output)
|
||||
|
||||
crew = self.agent.crew # type: ignore[union-attr]
|
||||
if crew and crew.task_callback and crew.task_callback != self.callback:
|
||||
crew.task_callback(self.output)
|
||||
|
||||
if self.output_file:
|
||||
content = (
|
||||
json_output
|
||||
if json_output
|
||||
else (
|
||||
pydantic_output.model_dump_json() if pydantic_output else result
|
||||
)
|
||||
)
|
||||
self._save_file(content)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
|
||||
)
|
||||
return task_output
|
||||
except Exception as e:
|
||||
self.end_time = datetime.datetime.now()
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
|
||||
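A hedged sketch of driving a single task through the new native-async path from application code; the agent and task definitions are illustrative, and a configured LLM credential is assumed:

import asyncio

from crewai.agent import Agent
from crewai.task import Task


async def main() -> None:
    agent = Agent(
        role="Researcher",
        goal="Summarize a topic",
        backstory="An illustrative agent.",
    )
    task = Task(
        description="Summarize the benefits of async execution.",
        expected_output="A short summary.",
        agent=agent,
    )
    # aexecute_sync awaits the agent end to end instead of blocking a thread.
    output = await task.aexecute_sync()
    print(output.raw)


asyncio.run(main())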
    def _execute_core(
        self,
        agent: BaseAgent | None,

@@ -539,7 +640,7 @@ class Task(BaseModel):

                json_dict=json_output,
                agent=agent.role,
                output_format=self._get_output_format(),
                messages=agent.last_messages,  # type: ignore[attr-defined]
            )

            if self._guardrails:

@@ -950,7 +1051,103 @@ Follow these guidelines:

                json_dict=json_output,
                agent=agent.role,
                output_format=self._get_output_format(),
                messages=agent.last_messages,  # type: ignore[attr-defined]
            )

        return task_output
    async def _ainvoke_guardrail_function(
        self,
        task_output: TaskOutput,
        agent: BaseAgent,
        tools: list[BaseTool],
        guardrail: GuardrailCallable | None,
        guardrail_index: int | None = None,
    ) -> TaskOutput:
        """Invoke the guardrail function asynchronously."""
        if not guardrail:
            return task_output

        if guardrail_index is not None:
            current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
        else:
            current_retry_count = self.retry_count

        max_attempts = self.guardrail_max_retries + 1

        for attempt in range(max_attempts):
            guardrail_result = process_guardrail(
                output=task_output,
                guardrail=guardrail,
                retry_count=current_retry_count,
                event_source=self,
                from_task=self,
                from_agent=agent,
            )

            if guardrail_result.success:
                if guardrail_result.result is None:
                    raise Exception(
                        "Task guardrail returned None as result. This is not allowed."
                    )

                if isinstance(guardrail_result.result, str):
                    task_output.raw = guardrail_result.result
                    pydantic_output, json_output = self._export_output(
                        guardrail_result.result
                    )
                    task_output.pydantic = pydantic_output
                    task_output.json_dict = json_output
                elif isinstance(guardrail_result.result, TaskOutput):
                    task_output = guardrail_result.result

                return task_output

            if attempt >= self.guardrail_max_retries:
                guardrail_name = (
                    f"guardrail {guardrail_index}"
                    if guardrail_index is not None
                    else "guardrail"
                )
                raise Exception(
                    f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
                    f"Last error: {guardrail_result.error}"
                )

            if guardrail_index is not None:
                current_retry_count += 1
                self._guardrail_retry_counts[guardrail_index] = current_retry_count
            else:
                self.retry_count += 1
                current_retry_count = self.retry_count

            context = self.i18n.errors("validation_error").format(
                guardrail_result_error=guardrail_result.error,
                task_output=task_output.raw,
            )
            printer = Printer()
            printer.print(
                content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
                color="yellow",
            )

            result = await agent.aexecute_task(
                task=self,
                context=context,
                tools=tools,
            )

            pydantic_output, json_output = self._export_output(result)
            task_output = TaskOutput(
                name=self.name or self.description,
                description=self.description,
                expected_output=self.expected_output,
                raw=result,
                pydantic=pydantic_output,
                json_dict=json_output,
                agent=agent.role,
                output_format=self._get_output_format(),
                messages=agent.last_messages,  # type: ignore[attr-defined]
            )

        return task_output
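The retry loop above accepts any callable that follows the guardrail contract of returning a (success, result) tuple. A minimal illustrative guardrail; the length threshold is arbitrary:

from crewai.tasks.task_output import TaskOutput


def max_length_guardrail(output: TaskOutput) -> tuple[bool, str]:
    """Illustrative guardrail: reject answers that run too long."""
    if len(output.raw) > 2000:
        # On failure, the error string is fed back into the retry prompt.
        return False, "Response exceeds 2000 characters; shorten it."
    return True, output.raw


# Attached at task construction, e.g.:
# task = Task(..., guardrail=max_length_guardrail)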
@@ -242,17 +242,17 @@ def get_llm_response(

    """Call the LLM and return the response, handling any invalid responses.

    Args:
        llm: The LLM instance to call.
        messages: The messages to send to the LLM.
        callbacks: List of callbacks for the LLM call.
        printer: Printer instance for output.
        from_task: Optional task context for the LLM call.
        from_agent: Optional agent context for the LLM call.
        response_model: Optional Pydantic model for structured outputs.
        executor_context: Optional executor context for hook invocation.

    Returns:
        The response from the LLM as a string.

    Raises:
        Exception: If an error occurs.

@@ -284,6 +284,60 @@ def get_llm_response(

    return _setup_after_llm_call_hooks(executor_context, answer, printer)
async def aget_llm_response(
    llm: LLM | BaseLLM,
    messages: list[LLMMessage],
    callbacks: list[TokenCalcHandler],
    printer: Printer,
    from_task: Task | None = None,
    from_agent: Agent | LiteAgent | None = None,
    response_model: type[BaseModel] | None = None,
    executor_context: CrewAgentExecutor | None = None,
) -> str:
    """Call the LLM asynchronously and return the response.

    Args:
        llm: The LLM instance to call.
        messages: The messages to send to the LLM.
        callbacks: List of callbacks for the LLM call.
        printer: Printer instance for output.
        from_task: Optional task context for the LLM call.
        from_agent: Optional agent context for the LLM call.
        response_model: Optional Pydantic model for structured outputs.
        executor_context: Optional executor context for hook invocation.

    Returns:
        The response from the LLM as a string.

    Raises:
        Exception: If an error occurs.
        ValueError: If the response is None or empty.
    """
    if executor_context is not None:
        if not _setup_before_llm_call_hooks(executor_context, printer):
            raise ValueError("LLM call blocked by before_llm_call hook")
        messages = executor_context.messages

    try:
        answer = await llm.acall(
            messages,
            callbacks=callbacks,
            from_task=from_task,
            from_agent=from_agent,  # type: ignore[arg-type]
            response_model=response_model,
        )
    except Exception as e:
        raise e
    if not answer:
        printer.print(
            content="Received None or empty response from LLM call.",
            color="red",
        )
        raise ValueError("Invalid response from LLM call - None or empty.")

    return _setup_after_llm_call_hooks(executor_context, answer, printer)
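A hedged sketch of calling the helper directly; in normal operation the executor's async loop does this internally. The model name is illustrative and an API key is assumed to be configured:

import asyncio

from crewai.llm import LLM
from crewai.utilities.agent_utils import aget_llm_response
from crewai.utilities.printer import Printer


async def main() -> None:
    llm = LLM(model="gpt-4o-mini")  # illustrative model name
    answer = await aget_llm_response(
        llm=llm,
        messages=[{"role": "user", "content": "Say hello."}],
        callbacks=[],
        printer=Printer(),
    )
    print(answer)


asyncio.run(main())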
def process_llm_response(
    answer: str, use_stop_words: bool
) -> AgentAction | AgentFinish:

@@ -51,6 +51,15 @@ class ConcreteAgentAdapter(BaseAgentAdapter):

        # Dummy implementation for MCP tools
        return []

    async def aexecute_task(
        self,
        task: Any,
        context: str | None = None,
        tools: list[Any] | None = None,
    ) -> str:
        # Dummy async implementation
        return "Task executed"


def test_base_agent_adapter_initialization():
    """Test initialization of the concrete agent adapter."""

@@ -25,6 +25,14 @@ class MockAgent(BaseAgent):

    def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
        return []

    async def aexecute_task(
        self,
        task: Any,
        context: str | None = None,
        tools: list[BaseTool] | None = None,
    ) -> str:
        return ""

    def get_output_converter(
        self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
    ): ...
lib/crewai/tests/agents/test_async_agent_executor.py (new file, 345 lines)
@@ -0,0 +1,345 @@
"""Tests for async agent executor functionality."""

import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.agents.parser import AgentAction, AgentFinish
from crewai.tools.tool_types import ToolResult


@pytest.fixture
def mock_llm() -> MagicMock:
    """Create a mock LLM for testing."""
    llm = MagicMock()
    llm.supports_stop_words.return_value = True
    llm.stop = []
    return llm


@pytest.fixture
def mock_agent() -> MagicMock:
    """Create a mock agent for testing."""
    agent = MagicMock()
    agent.role = "Test Agent"
    agent.key = "test_agent_key"
    agent.verbose = False
    agent.id = "test_agent_id"
    return agent


@pytest.fixture
def mock_task() -> MagicMock:
    """Create a mock task for testing."""
    task = MagicMock()
    task.description = "Test task description"
    return task


@pytest.fixture
def mock_crew() -> MagicMock:
    """Create a mock crew for testing."""
    crew = MagicMock()
    crew.verbose = False
    crew._train = False
    return crew


@pytest.fixture
def mock_tools_handler() -> MagicMock:
    """Create a mock tools handler."""
    return MagicMock()


@pytest.fixture
def executor(
    mock_llm: MagicMock,
    mock_agent: MagicMock,
    mock_task: MagicMock,
    mock_crew: MagicMock,
    mock_tools_handler: MagicMock,
) -> CrewAgentExecutor:
    """Create a CrewAgentExecutor instance for testing."""
    return CrewAgentExecutor(
        llm=mock_llm,
        task=mock_task,
        crew=mock_crew,
        agent=mock_agent,
        prompt={"prompt": "Test prompt {input} {tool_names} {tools}"},
        max_iter=5,
        tools=[],
        tools_names="",
        stop_words=["Observation:"],
        tools_description="",
        tools_handler=mock_tools_handler,
    )


class TestAsyncAgentExecutor:
    """Tests for async agent executor methods."""

    @pytest.mark.asyncio
    async def test_ainvoke_returns_output(self, executor: CrewAgentExecutor) -> None:
        """Test that ainvoke returns the expected output."""
        expected_output = "Final answer from agent"

        with patch.object(
            executor,
            "_ainvoke_loop",
            new_callable=AsyncMock,
            return_value=AgentFinish(
                thought="Done", output=expected_output, text="Final Answer: Done"
            ),
        ):
            with patch.object(executor, "_show_start_logs"):
                with patch.object(executor, "_create_short_term_memory"):
                    with patch.object(executor, "_create_long_term_memory"):
                        with patch.object(executor, "_create_external_memory"):
                            result = await executor.ainvoke(
                                {
                                    "input": "test input",
                                    "tool_names": "",
                                    "tools": "",
                                }
                            )

        assert result == {"output": expected_output}

    @pytest.mark.asyncio
    async def test_ainvoke_loop_calls_aget_llm_response(
        self, executor: CrewAgentExecutor
    ) -> None:
        """Test that _ainvoke_loop calls aget_llm_response."""
        with patch(
            "crewai.agents.crew_agent_executor.aget_llm_response",
            new_callable=AsyncMock,
            return_value="Thought: I know the answer\nFinal Answer: Test result",
        ) as mock_aget_llm:
            with patch.object(executor, "_show_logs"):
                result = await executor._ainvoke_loop()

        mock_aget_llm.assert_called_once()
        assert isinstance(result, AgentFinish)

    @pytest.mark.asyncio
    async def test_ainvoke_loop_handles_tool_execution(
        self,
        executor: CrewAgentExecutor,
    ) -> None:
        """Test that _ainvoke_loop handles tool execution asynchronously."""
        call_count = 0

        async def mock_llm_response(*args: Any, **kwargs: Any) -> str:
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return (
                    "Thought: I need to use a tool\n"
                    "Action: test_tool\n"
                    'Action Input: {"arg": "value"}'
                )
            return "Thought: I have the answer\nFinal Answer: Tool result processed"

        with patch(
            "crewai.agents.crew_agent_executor.aget_llm_response",
            new_callable=AsyncMock,
            side_effect=mock_llm_response,
        ):
            with patch(
                "crewai.agents.crew_agent_executor.aexecute_tool_and_check_finality",
                new_callable=AsyncMock,
                return_value=ToolResult(result="Tool executed", result_as_answer=False),
            ) as mock_tool_exec:
                with patch.object(executor, "_show_logs"):
                    with patch.object(executor, "_handle_agent_action") as mock_handle:
                        mock_handle.return_value = AgentAction(
                            text="Tool result",
                            tool="test_tool",
                            tool_input='{"arg": "value"}',
                            thought="Used tool",
                            result="Tool executed",
                        )
                        result = await executor._ainvoke_loop()

        assert mock_tool_exec.called
        assert isinstance(result, AgentFinish)

    @pytest.mark.asyncio
    async def test_ainvoke_loop_respects_max_iterations(
        self, executor: CrewAgentExecutor
    ) -> None:
        """Test that _ainvoke_loop respects max iterations."""
        executor.max_iter = 2

        async def always_return_action(*args: Any, **kwargs: Any) -> str:
            return (
                "Thought: I need to think more\n"
                "Action: some_tool\n"
                "Action Input: {}"
            )

        with patch(
            "crewai.agents.crew_agent_executor.aget_llm_response",
            new_callable=AsyncMock,
            side_effect=always_return_action,
        ):
            with patch(
                "crewai.agents.crew_agent_executor.aexecute_tool_and_check_finality",
                new_callable=AsyncMock,
                return_value=ToolResult(result="Tool result", result_as_answer=False),
            ):
                with patch(
                    "crewai.agents.crew_agent_executor.handle_max_iterations_exceeded",
                    return_value=AgentFinish(
                        thought="Max iterations",
                        output="Forced answer",
                        text="Max iterations reached",
                    ),
                ) as mock_max_iter:
                    with patch.object(executor, "_show_logs"):
                        with patch.object(executor, "_handle_agent_action") as mock_ha:
                            mock_ha.return_value = AgentAction(
                                text="Action",
                                tool="some_tool",
                                tool_input="{}",
                                thought="Thinking",
                            )
                            result = await executor._ainvoke_loop()

        mock_max_iter.assert_called_once()
        assert isinstance(result, AgentFinish)

    @pytest.mark.asyncio
    async def test_ainvoke_handles_exceptions(
        self, executor: CrewAgentExecutor
    ) -> None:
        """Test that ainvoke properly propagates exceptions."""
        with patch.object(executor, "_show_start_logs"):
            with patch.object(
                executor,
                "_ainvoke_loop",
                new_callable=AsyncMock,
                side_effect=ValueError("Test error"),
            ):
                with pytest.raises(ValueError, match="Test error"):
                    await executor.ainvoke(
                        {"input": "test", "tool_names": "", "tools": ""}
                    )

    @pytest.mark.asyncio
    async def test_concurrent_ainvoke_calls(
        self, mock_llm: MagicMock, mock_agent: MagicMock, mock_task: MagicMock,
        mock_crew: MagicMock, mock_tools_handler: MagicMock
    ) -> None:
        """Test that multiple ainvoke calls can run concurrently."""

        async def create_and_run_executor(executor_id: int) -> dict[str, Any]:
            executor = CrewAgentExecutor(
                llm=mock_llm,
                task=mock_task,
                crew=mock_crew,
                agent=mock_agent,
                prompt={"prompt": "Test {input} {tool_names} {tools}"},
                max_iter=5,
                tools=[],
                tools_names="",
                stop_words=["Observation:"],
                tools_description="",
                tools_handler=mock_tools_handler,
            )

            async def delayed_response(*args: Any, **kwargs: Any) -> str:
                await asyncio.sleep(0.05)
                return f"Thought: Done\nFinal Answer: Result from executor {executor_id}"

            with patch(
                "crewai.agents.crew_agent_executor.aget_llm_response",
                new_callable=AsyncMock,
                side_effect=delayed_response,
            ):
                with patch.object(executor, "_show_start_logs"):
                    with patch.object(executor, "_show_logs"):
                        with patch.object(executor, "_create_short_term_memory"):
                            with patch.object(executor, "_create_long_term_memory"):
                                with patch.object(executor, "_create_external_memory"):
                                    return await executor.ainvoke(
                                        {
                                            "input": f"test {executor_id}",
                                            "tool_names": "",
                                            "tools": "",
                                        }
                                    )

        import time

        start = time.time()
        results = await asyncio.gather(
            create_and_run_executor(1),
            create_and_run_executor(2),
            create_and_run_executor(3),
        )
        elapsed = time.time() - start

        assert len(results) == 3
        assert all("output" in r for r in results)
        assert elapsed < 0.15, f"Expected concurrent execution, took {elapsed}s"


class TestAsyncLLMResponseHelper:
    """Tests for aget_llm_response helper function."""

    @pytest.mark.asyncio
    async def test_aget_llm_response_calls_acall(self) -> None:
        """Test that aget_llm_response calls llm.acall."""
        from crewai.utilities.agent_utils import aget_llm_response
        from crewai.utilities.printer import Printer

        mock_llm = MagicMock()
        mock_llm.acall = AsyncMock(return_value="LLM response")

        result = await aget_llm_response(
            llm=mock_llm,
            messages=[{"role": "user", "content": "test"}],
            callbacks=[],
            printer=Printer(),
        )

        mock_llm.acall.assert_called_once()
        assert result == "LLM response"

    @pytest.mark.asyncio
    async def test_aget_llm_response_raises_on_empty_response(self) -> None:
        """Test that aget_llm_response raises ValueError on empty response."""
        from crewai.utilities.agent_utils import aget_llm_response
        from crewai.utilities.printer import Printer

        mock_llm = MagicMock()
        mock_llm.acall = AsyncMock(return_value="")

        with pytest.raises(ValueError, match="Invalid response from LLM call"):
            await aget_llm_response(
                llm=mock_llm,
                messages=[{"role": "user", "content": "test"}],
                callbacks=[],
                printer=Printer(),
            )

    @pytest.mark.asyncio
    async def test_aget_llm_response_propagates_exceptions(self) -> None:
        """Test that aget_llm_response propagates LLM exceptions."""
        from crewai.utilities.agent_utils import aget_llm_response
        from crewai.utilities.printer import Printer

        mock_llm = MagicMock()
        mock_llm.acall = AsyncMock(side_effect=RuntimeError("LLM error"))

        with pytest.raises(RuntimeError, match="LLM error"):
            await aget_llm_response(
                llm=mock_llm,
                messages=[{"role": "user", "content": "test"}],
                callbacks=[],
                printer=Printer(),
            )
lib/crewai/tests/crew/test_async_crew.py (new file, 384 lines)
@@ -0,0 +1,384 @@
"""Tests for async crew execution."""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
from crewai.crews.crew_output import CrewOutput
from crewai.tasks.task_output import TaskOutput


@pytest.fixture
def test_agent() -> Agent:
    """Create a test agent."""
    return Agent(
        role="Test Agent",
        goal="Test goal",
        backstory="Test backstory",
        llm="gpt-4o-mini",
        verbose=False,
    )


@pytest.fixture
def test_task(test_agent: Agent) -> Task:
    """Create a test task."""
    return Task(
        description="Test task description",
        expected_output="Test expected output",
        agent=test_agent,
    )


@pytest.fixture
def test_crew(test_agent: Agent, test_task: Task) -> Crew:
    """Create a test crew."""
    return Crew(
        agents=[test_agent],
        tasks=[test_task],
        verbose=False,
    )


class TestAsyncCrewKickoff:
    """Tests for async crew kickoff methods."""

    @pytest.mark.asyncio
    @patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
    async def test_akickoff_basic(
        self, mock_execute: AsyncMock, test_crew: Crew
    ) -> None:
        """Test basic async crew kickoff."""
        mock_output = TaskOutput(
            description="Test task description",
            raw="Task result",
            agent="Test Agent",
        )
        mock_execute.return_value = mock_output

        result = await test_crew.akickoff()

        assert result is not None
        assert isinstance(result, CrewOutput)
        assert result.raw == "Task result"
        mock_execute.assert_called_once()

    @pytest.mark.asyncio
    @patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
    async def test_akickoff_with_inputs(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async crew kickoff with inputs."""
        task = Task(
            description="Test task for {topic}",
            expected_output="Expected output for {topic}",
            agent=test_agent,
        )
        crew = Crew(
            agents=[test_agent],
            tasks=[task],
            verbose=False,
        )

        mock_output = TaskOutput(
            description="Test task for AI",
            raw="Task result about AI",
            agent="Test Agent",
        )
        mock_execute.return_value = mock_output

        result = await crew.akickoff(inputs={"topic": "AI"})

        assert result is not None
        assert isinstance(result, CrewOutput)
        mock_execute.assert_called_once()

    @pytest.mark.asyncio
    @patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
    async def test_akickoff_multiple_tasks(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async crew kickoff with multiple tasks."""
        task1 = Task(
            description="First task",
            expected_output="First output",
            agent=test_agent,
        )
        task2 = Task(
            description="Second task",
            expected_output="Second output",
            agent=test_agent,
        )
        crew = Crew(
            agents=[test_agent],
            tasks=[task1, task2],
            verbose=False,
        )

        mock_output1 = TaskOutput(
            description="First task",
            raw="First result",
            agent="Test Agent",
        )
        mock_output2 = TaskOutput(
            description="Second task",
            raw="Second result",
            agent="Test Agent",
        )
        mock_execute.side_effect = [mock_output1, mock_output2]

        result = await crew.akickoff()

        assert result is not None
        assert isinstance(result, CrewOutput)
        assert result.raw == "Second result"
        assert mock_execute.call_count == 2

    @pytest.mark.asyncio
    @patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
    async def test_akickoff_handles_exception(
        self, mock_execute: AsyncMock, test_crew: Crew
    ) -> None:
        """Test that async kickoff handles exceptions properly."""
        mock_execute.side_effect = RuntimeError("Test error")

        with pytest.raises(RuntimeError) as exc_info:
            await test_crew.akickoff()

        assert "Test error" in str(exc_info.value)

    @pytest.mark.asyncio
    @patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
    async def test_akickoff_calls_before_callbacks(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test that async kickoff calls before_kickoff_callbacks."""
        callback_called = False

        def before_callback(inputs: dict | None) -> dict:
            nonlocal callback_called
            callback_called = True
            return inputs or {}

        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=test_agent,
        )
        crew = Crew(
            agents=[test_agent],
            tasks=[task],
            verbose=False,
            before_kickoff_callbacks=[before_callback],
        )

        mock_output = TaskOutput(
            description="Test task",
            raw="Task result",
            agent="Test Agent",
        )
        mock_execute.return_value = mock_output

        await crew.akickoff()

        assert callback_called

    @pytest.mark.asyncio
    @patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
    async def test_akickoff_calls_after_callbacks(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test that async kickoff calls after_kickoff_callbacks."""
        callback_called = False

        def after_callback(result: CrewOutput) -> CrewOutput:
            nonlocal callback_called
            callback_called = True
            return result

        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=test_agent,
        )
        crew = Crew(
            agents=[test_agent],
            tasks=[task],
            verbose=False,
            after_kickoff_callbacks=[after_callback],
        )

        mock_output = TaskOutput(
            description="Test task",
            raw="Task result",
            agent="Test Agent",
        )
        mock_execute.return_value = mock_output

        await crew.akickoff()

        assert callback_called


class TestAsyncCrewKickoffForEach:
    """Tests for async crew kickoff_for_each methods."""

    @pytest.mark.asyncio
    @patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
    async def test_akickoff_for_each_basic(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test basic async kickoff_for_each."""
        task = Task(
            description="Test task for {topic}",
            expected_output="Expected output",
            agent=test_agent,
        )
        crew = Crew(
            agents=[test_agent],
            tasks=[task],
            verbose=False,
        )

        mock_output1 = TaskOutput(
            description="Test task for AI",
            raw="Result about AI",
            agent="Test Agent",
        )
        mock_output2 = TaskOutput(
            description="Test task for ML",
            raw="Result about ML",
            agent="Test Agent",
        )
        mock_execute.side_effect = [mock_output1, mock_output2]

        inputs = [{"topic": "AI"}, {"topic": "ML"}]
        results = await crew.akickoff_for_each(inputs)

        assert len(results) == 2
        assert all(isinstance(r, CrewOutput) for r in results)

    @pytest.mark.asyncio
    @patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
    async def test_akickoff_for_each_concurrent(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test that async kickoff_for_each runs concurrently."""
        task = Task(
            description="Test task for {topic}",
            expected_output="Expected output",
            agent=test_agent,
        )
        crew = Crew(
            agents=[test_agent],
            tasks=[task],
            verbose=False,
        )

        mock_output = TaskOutput(
            description="Test task",
            raw="Result",
            agent="Test Agent",
        )
        mock_execute.return_value = mock_output

        inputs = [{"topic": f"topic_{i}"} for i in range(3)]
        results = await crew.akickoff_for_each(inputs)

        assert len(results) == 3


class TestAsyncTaskExecution:
    """Tests for async task execution within crew."""

    @pytest.mark.asyncio
    @patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
    async def test_aexecute_tasks_sequential(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async sequential task execution."""
        task1 = Task(
            description="First task",
            expected_output="First output",
            agent=test_agent,
        )
        task2 = Task(
            description="Second task",
            expected_output="Second output",
            agent=test_agent,
        )
        crew = Crew(
            agents=[test_agent],
            tasks=[task1, task2],
            verbose=False,
        )

        mock_output1 = TaskOutput(
            description="First task",
            raw="First result",
            agent="Test Agent",
        )
        mock_output2 = TaskOutput(
            description="Second task",
            raw="Second result",
            agent="Test Agent",
        )
        mock_execute.side_effect = [mock_output1, mock_output2]

        result = await crew._aexecute_tasks(crew.tasks)

        assert result is not None
        assert result.raw == "Second result"
        assert len(result.tasks_output) == 2

    @pytest.mark.asyncio
    @patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
    async def test_aexecute_tasks_with_async_task(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async execution with async_execution task flag."""
        task1 = Task(
            description="Async task",
            expected_output="Async output",
            agent=test_agent,
            async_execution=True,
        )
        task2 = Task(
            description="Sync task",
            expected_output="Sync output",
            agent=test_agent,
        )
        crew = Crew(
            agents=[test_agent],
            tasks=[task1, task2],
            verbose=False,
        )

        mock_output1 = TaskOutput(
            description="Async task",
            raw="Async result",
            agent="Test Agent",
        )
        mock_output2 = TaskOutput(
            description="Sync task",
            raw="Sync result",
            agent="Test Agent",
        )
        mock_execute.side_effect = [mock_output1, mock_output2]

        result = await crew._aexecute_tasks(crew.tasks)

        assert result is not None
        assert mock_execute.call_count == 2


class TestAsyncProcessAsyncTasks:
    """Tests for _aprocess_async_tasks method."""

    @pytest.mark.asyncio
    async def test_aprocess_async_tasks_empty(self, test_crew: Crew) -> None:
        """Test processing empty list of async tasks."""
        result = await test_crew._aprocess_async_tasks([])
        assert result == []
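A hedged sketch of the application-level flow these tests exercise, batching several inputs on one event loop; the agent, task, and topics are illustrative, and a configured LLM credential is assumed:

import asyncio

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task


async def main() -> None:
    agent = Agent(role="Writer", goal="Write briefs", backstory="Illustrative.")
    task = Task(
        description="Write a brief on {topic}.",
        expected_output="One paragraph.",
        agent=agent,
    )
    crew = Crew(agents=[agent], tasks=[task])
    # Run several independent inputs concurrently instead of one kickoff at a time.
    outputs = await crew.akickoff_for_each([{"topic": "AI"}, {"topic": "ML"}])
    print([o.raw for o in outputs])


asyncio.run(main())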
lib/crewai/tests/knowledge/test_async_knowledge.py (new file, 212 lines)
@@ -0,0 +1,212 @@
"""Tests for async knowledge operations."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage


class TestAsyncKnowledgeStorage:
    """Tests for async KnowledgeStorage operations."""

    @pytest.mark.asyncio
    async def test_asearch_returns_results(self):
        """Test that asearch returns search results."""
        mock_client = MagicMock()
        mock_client.asearch = AsyncMock(
            return_value=[{"content": "test result", "score": 0.9}]
        )

        storage = KnowledgeStorage(collection_name="test_collection")
        storage._client = mock_client

        results = await storage.asearch(["test query"])

        assert len(results) == 1
        assert results[0]["content"] == "test result"
        mock_client.asearch.assert_called_once()

    @pytest.mark.asyncio
    async def test_asearch_empty_query_raises_error(self):
        """Test that asearch handles empty query."""
        storage = KnowledgeStorage(collection_name="test_collection")

        # Empty query should not raise but return empty results due to error handling
        results = await storage.asearch([])
        assert results == []

    @pytest.mark.asyncio
    async def test_asave_calls_client_methods(self):
        """Test that asave calls the correct client methods."""
        mock_client = MagicMock()
        mock_client.aget_or_create_collection = AsyncMock()
        mock_client.aadd_documents = AsyncMock()

        storage = KnowledgeStorage(collection_name="test_collection")
        storage._client = mock_client

        await storage.asave(["document 1", "document 2"])

        mock_client.aget_or_create_collection.assert_called_once_with(
            collection_name="knowledge_test_collection"
        )
        mock_client.aadd_documents.assert_called_once()

    @pytest.mark.asyncio
    async def test_areset_calls_client_delete(self):
        """Test that areset calls delete_collection on the client."""
        mock_client = MagicMock()
        mock_client.adelete_collection = AsyncMock()

        storage = KnowledgeStorage(collection_name="test_collection")
        storage._client = mock_client

        await storage.areset()

        mock_client.adelete_collection.assert_called_once_with(
            collection_name="knowledge_test_collection"
        )


class TestAsyncKnowledge:
    """Tests for async Knowledge operations."""

    @pytest.mark.asyncio
    async def test_aquery_calls_storage_asearch(self):
        """Test that aquery calls storage.asearch."""
        mock_storage = MagicMock(spec=KnowledgeStorage)
        mock_storage.asearch = AsyncMock(
            return_value=[{"content": "result", "score": 0.8}]
        )

        knowledge = Knowledge(
            collection_name="test",
            sources=[],
            storage=mock_storage,
        )

        results = await knowledge.aquery(["test query"])

        assert len(results) == 1
        mock_storage.asearch.assert_called_once_with(
            ["test query"],
            limit=5,
            score_threshold=0.6,
        )

    @pytest.mark.asyncio
    async def test_aquery_raises_when_storage_not_initialized(self):
        """Test that aquery raises ValueError when storage is None."""
        knowledge = Knowledge(
            collection_name="test",
            sources=[],
            storage=MagicMock(spec=KnowledgeStorage),
        )
        knowledge.storage = None

        with pytest.raises(ValueError, match="Storage is not initialized"):
            await knowledge.aquery(["test query"])

    @pytest.mark.asyncio
    async def test_aadd_sources_calls_source_aadd(self):
        """Test that aadd_sources calls aadd on each source."""
        mock_storage = MagicMock(spec=KnowledgeStorage)
        mock_source = MagicMock()
        mock_source.aadd = AsyncMock()

        knowledge = Knowledge(
            collection_name="test",
            sources=[mock_source],
            storage=mock_storage,
        )

        await knowledge.aadd_sources()

        mock_source.aadd.assert_called_once()
        assert mock_source.storage == mock_storage

    @pytest.mark.asyncio
    async def test_areset_calls_storage_areset(self):
        """Test that areset calls storage.areset."""
        mock_storage = MagicMock(spec=KnowledgeStorage)
        mock_storage.areset = AsyncMock()

        knowledge = Knowledge(
            collection_name="test",
            sources=[],
            storage=mock_storage,
        )

        await knowledge.areset()

        mock_storage.areset.assert_called_once()

    @pytest.mark.asyncio
    async def test_areset_raises_when_storage_not_initialized(self):
        """Test that areset raises ValueError when storage is None."""
        knowledge = Knowledge(
            collection_name="test",
            sources=[],
            storage=MagicMock(spec=KnowledgeStorage),
        )
        knowledge.storage = None

        with pytest.raises(ValueError, match="Storage is not initialized"):
            await knowledge.areset()


class TestAsyncStringKnowledgeSource:
    """Tests for async StringKnowledgeSource operations."""

    @pytest.mark.asyncio
    async def test_aadd_saves_documents_asynchronously(self):
        """Test that aadd chunks and saves documents asynchronously."""
        mock_storage = MagicMock(spec=KnowledgeStorage)
        mock_storage.asave = AsyncMock()

        source = StringKnowledgeSource(content="Test content for async processing")
        source.storage = mock_storage

        await source.aadd()

        mock_storage.asave.assert_called_once()
        assert len(source.chunks) > 0

    @pytest.mark.asyncio
    async def test_aadd_raises_without_storage(self):
        """Test that aadd raises ValueError when storage is not set."""
        source = StringKnowledgeSource(content="Test content")
        source.storage = None

        with pytest.raises(ValueError, match="No storage found"):
            await source.aadd()


class TestAsyncBaseKnowledgeSource:
    """Tests for async _asave_documents method."""

    @pytest.mark.asyncio
    async def test_asave_documents_calls_storage_asave(self):
        """Test that _asave_documents calls storage.asave."""
        mock_storage = MagicMock(spec=KnowledgeStorage)
        mock_storage.asave = AsyncMock()

        source = StringKnowledgeSource(content="Test")
        source.storage = mock_storage
        source.chunks = ["chunk1", "chunk2"]

        await source._asave_documents()

        mock_storage.asave.assert_called_once_with(["chunk1", "chunk2"])

    @pytest.mark.asyncio
    async def test_asave_documents_raises_without_storage(self):
        """Test that _asave_documents raises ValueError when storage is None."""
        source = StringKnowledgeSource(content="Test")
        source.storage = None

        with pytest.raises(ValueError, match="No storage found"):
            await source._asave_documents()
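A hedged sketch of the end-to-end async knowledge flow that these tests mock out; the collection name and content are illustrative, and a default storage and embedder configuration is assumed to be created when none is passed:

import asyncio

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource


async def main() -> None:
    source = StringKnowledgeSource(content="CrewAI supports async knowledge.")
    knowledge = Knowledge(collection_name="demo", sources=[source])
    await knowledge.aadd_sources()  # chunk, embed, and store without blocking the loop
    hits = await knowledge.aquery(["async knowledge"])
    print(hits)


asyncio.run(main())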
lib/crewai/tests/memory/test_async_memory.py (new file, 496 lines)
@@ -0,0 +1,496 @@
|
||||
"""Tests for async memory operations."""
|
||||
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Fixture to create a mock agent."""
|
||||
return Agent(
|
||||
role="Researcher",
|
||||
goal="Search relevant data and provide results",
|
||||
backstory="You are a researcher at a leading tech think tank.",
|
||||
tools=[],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task(mock_agent):
|
||||
"""Fixture to create a mock task."""
|
||||
return Task(
|
||||
description="Perform a search on specific topics.",
|
||||
expected_output="A list of relevant URLs based on the search query.",
|
||||
agent=mock_agent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_term_memory(mock_agent, mock_task):
|
||||
"""Fixture to create a ShortTermMemory instance."""
|
||||
return ShortTermMemory(crew=Crew(agents=[mock_agent], tasks=[mock_task]))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory(tmp_path):
|
||||
"""Fixture to create a LongTermMemory instance."""
|
||||
db_path = str(tmp_path / "test_ltm.db")
|
||||
return LongTermMemory(path=db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity_memory(tmp_path, mock_agent, mock_task):
|
||||
"""Fixture to create an EntityMemory instance."""
|
||||
return EntityMemory(
|
||||
crew=Crew(agents=[mock_agent], tasks=[mock_task]),
|
||||
path=str(tmp_path / "test_entities"),
|
||||
)
|
||||
|
||||
|
||||
class TestAsyncShortTermMemory:
|
||||
"""Tests for async ShortTermMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_emits_events(self, short_term_memory):
|
||||
"""Test that asave emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
await short_term_memory.asave(
|
||||
value="async test value",
|
||||
metadata={"task": "async_test_task"},
|
||||
)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].value == "async test value"
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "short_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, short_term_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
with patch.object(short_term_memory.storage, "asearch", new_callable=AsyncMock, return_value=[]):
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await short_term_memory.asearch(
|
||||
query="async test query",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].query == "async test query"
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "short_term_memory"
|
||||
|
||||
|
||||
class TestAsyncLongTermMemory:
|
||||
"""Tests for async LongTermMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_emits_events(self, long_term_memory):
|
||||
"""Test that asave emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
item = LongTermMemoryItem(
|
||||
task="async test task",
|
||||
agent="test_agent",
|
||||
expected_output="test output",
|
||||
datetime="2024-01-01T00:00:00",
|
||||
quality=0.9,
|
||||
metadata={"task": "async test task", "quality": 0.9},
|
||||
)
|
||||
|
||||
await long_term_memory.asave(item)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "long_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, long_term_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await long_term_memory.asearch(task="async test task", latest_n=3)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "long_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_and_asearch_integration(self, long_term_memory):
|
||||
"""Test that asave followed by asearch works correctly."""
|
||||
item = LongTermMemoryItem(
|
||||
task="integration test task",
|
||||
agent="test_agent",
|
||||
expected_output="test output",
|
||||
datetime="2024-01-01T00:00:00",
|
||||
quality=0.9,
|
||||
metadata={"task": "integration test task", "quality": 0.9},
|
||||
)
|
||||
|
||||
await long_term_memory.asave(item)
|
||||
results = await long_term_memory.asearch(task="integration test task", latest_n=1)
|
||||
|
||||
assert results is not None
|
||||
assert len(results) == 1
|
||||
assert results[0]["metadata"]["agent"] == "test_agent"
|
||||
|
||||
|
||||
class TestAsyncEntityMemory:
|
||||
"""Tests for async EntityMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_single_item_emits_events(self, entity_memory):
|
||||
"""Test that asave with a single item emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
item = EntityMemoryItem(
|
||||
name="TestEntity",
|
||||
type="Person",
|
||||
description="A test entity for async operations",
|
||||
relationships="Related to other test entities",
|
||||
)
|
||||
|
||||
await entity_memory.asave(item)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "entity_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, entity_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await entity_memory.asearch(query="TestEntity", limit=5, score_threshold=0.6)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "entity_memory"
|
||||
|
||||
|
||||
class TestAsyncContextualMemory:
    """Tests for async ContextualMemory operations."""

    @pytest.mark.asyncio
    async def test_abuild_context_for_task_with_empty_query(self, mock_task):
        """Test that abuild_context_for_task returns empty string for empty query."""
        mock_task.description = ""
        contextual_memory = ContextualMemory(
            stm=None,
            ltm=None,
            em=None,
            exm=None,
        )

        result = await contextual_memory.abuild_context_for_task(mock_task, "")
        assert result == ""

    @pytest.mark.asyncio
    async def test_abuild_context_for_task_with_none_memories(self, mock_task):
        """Test that abuild_context_for_task handles None memory sources."""
        contextual_memory = ContextualMemory(
            stm=None,
            ltm=None,
            em=None,
            exm=None,
        )

        result = await contextual_memory.abuild_context_for_task(mock_task, "some context")
        assert result == ""

    @pytest.mark.asyncio
    async def test_abuild_context_for_task_aggregates_results(self, mock_agent, mock_task):
        """Test that abuild_context_for_task aggregates results from all memory sources."""
        mock_stm = MagicMock(spec=ShortTermMemory)
        mock_stm.asearch = AsyncMock(return_value=[{"content": "STM insight"}])

        mock_ltm = MagicMock(spec=LongTermMemory)
        mock_ltm.asearch = AsyncMock(
            return_value=[{"metadata": {"suggestions": ["LTM suggestion"]}}]
        )

        mock_em = MagicMock(spec=EntityMemory)
        mock_em.asearch = AsyncMock(return_value=[{"content": "Entity info"}])

        mock_exm = MagicMock(spec=ExternalMemory)
        mock_exm.asearch = AsyncMock(return_value=[{"content": "External memory"}])

        contextual_memory = ContextualMemory(
            stm=mock_stm,
            ltm=mock_ltm,
            em=mock_em,
            exm=mock_exm,
            agent=mock_agent,
            task=mock_task,
        )

        result = await contextual_memory.abuild_context_for_task(mock_task, "additional context")

        assert "Recent Insights:" in result
        assert "STM insight" in result
        assert "Historical Data:" in result
        assert "LTM suggestion" in result
        assert "Entities:" in result
        assert "Entity info" in result
        assert "External memories:" in result
        assert "External memory" in result

    @pytest.mark.asyncio
    async def test_afetch_stm_context_returns_formatted_results(self, mock_agent, mock_task):
        """Test that _afetch_stm_context returns properly formatted results."""
        mock_stm = MagicMock(spec=ShortTermMemory)
        mock_stm.asearch = AsyncMock(
            return_value=[
                {"content": "First insight"},
                {"content": "Second insight"},
            ]
        )

        contextual_memory = ContextualMemory(
            stm=mock_stm,
            ltm=None,
            em=None,
            exm=None,
        )

        result = await contextual_memory._afetch_stm_context("test query")

        assert "Recent Insights:" in result
        assert "- First insight" in result
        assert "- Second insight" in result

    @pytest.mark.asyncio
    async def test_afetch_ltm_context_returns_formatted_results(self, mock_agent, mock_task):
        """Test that _afetch_ltm_context returns properly formatted results."""
        mock_ltm = MagicMock(spec=LongTermMemory)
        mock_ltm.asearch = AsyncMock(
            return_value=[
                {"metadata": {"suggestions": ["Suggestion 1", "Suggestion 2"]}},
            ]
        )

        contextual_memory = ContextualMemory(
            stm=None,
            ltm=mock_ltm,
            em=None,
            exm=None,
        )

        result = await contextual_memory._afetch_ltm_context("test task")

        assert "Historical Data:" in result
        assert "- Suggestion 1" in result
        assert "- Suggestion 2" in result

    @pytest.mark.asyncio
    async def test_afetch_entity_context_returns_formatted_results(self, mock_agent, mock_task):
        """Test that _afetch_entity_context returns properly formatted results."""
        mock_em = MagicMock(spec=EntityMemory)
        mock_em.asearch = AsyncMock(
            return_value=[
                {"content": "Entity A details"},
                {"content": "Entity B details"},
            ]
        )

        contextual_memory = ContextualMemory(
            stm=None,
            ltm=None,
            em=mock_em,
            exm=None,
        )

        result = await contextual_memory._afetch_entity_context("test query")

        assert "Entities:" in result
        assert "- Entity A details" in result
        assert "- Entity B details" in result

    @pytest.mark.asyncio
    async def test_afetch_external_context_returns_formatted_results(self):
        """Test that _afetch_external_context returns properly formatted results."""
        mock_exm = MagicMock(spec=ExternalMemory)
        mock_exm.asearch = AsyncMock(
            return_value=[
                {"content": "External data 1"},
                {"content": "External data 2"},
            ]
        )

        contextual_memory = ContextualMemory(
            stm=None,
            ltm=None,
            em=None,
            exm=mock_exm,
        )

        result = await contextual_memory._afetch_external_context("test query")

        assert "External memories:" in result
        assert "- External data 1" in result
        assert "- External data 2" in result

    @pytest.mark.asyncio
    async def test_afetch_methods_return_empty_for_empty_results(self):
        """Test that async fetch methods return empty results when nothing matches."""
        mock_stm = MagicMock(spec=ShortTermMemory)
        mock_stm.asearch = AsyncMock(return_value=[])

        mock_ltm = MagicMock(spec=LongTermMemory)
        mock_ltm.asearch = AsyncMock(return_value=[])

        mock_em = MagicMock(spec=EntityMemory)
        mock_em.asearch = AsyncMock(return_value=[])

        mock_exm = MagicMock(spec=ExternalMemory)
        mock_exm.asearch = AsyncMock(return_value=[])

        contextual_memory = ContextualMemory(
            stm=mock_stm,
            ltm=mock_ltm,
            em=mock_em,
            exm=mock_exm,
        )

        stm_result = await contextual_memory._afetch_stm_context("query")
        ltm_result = await contextual_memory._afetch_ltm_context("task")
        em_result = await contextual_memory._afetch_entity_context("query")
        exm_result = await contextual_memory._afetch_external_context("query")

        assert stm_result == ""
        assert ltm_result is None  # the LTM fetcher signals "no history" with None
        assert em_result == ""
        assert exm_result == ""
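To make the aggregation shape concrete, here is a minimal, hypothetical usage sketch built from the API exercised above (ContextualMemory, its asearch-backed sources, abuild_context_for_task); the import paths, Task fields, and query string are assumptions, not part of this diff:

import asyncio

from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.task import Task

async def main():
    # Only STM is enabled; the other sources are disabled with None,
    # mirroring several of the tests above.
    memory = ContextualMemory(stm=ShortTermMemory(), ltm=None, em=None, exm=None)
    task = Task(description="Summarize recent findings", expected_output="A short summary")
    # Awaits each enabled source's asearch and stitches the formatted
    # sections ("Recent Insights:", "Historical Data:", ...) together.
    context = await memory.abuild_context_for_task(task, "recent findings")
    print(context)

asyncio.run(main())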
lib/crewai/tests/task/test_async_task.py (new file, 386 lines)
@@ -0,0 +1,386 @@
"""Tests for async task execution."""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from crewai.agent import Agent
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
from crewai.tasks.output_format import OutputFormat


@pytest.fixture
def test_agent() -> Agent:
    """Create a test agent."""
    return Agent(
        role="Test Agent",
        goal="Test goal",
        backstory="Test backstory",
        llm="gpt-4o-mini",
        verbose=False,
    )


class TestAsyncTaskExecution:
    """Tests for async task execution methods."""

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_aexecute_sync_basic(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test basic async task execution."""
        mock_execute.return_value = "Async task result"
        task = Task(
            description="Test task description",
            expected_output="Test expected output",
            agent=test_agent,
        )

        result = await task.aexecute_sync()

        assert result is not None
        assert isinstance(result, TaskOutput)
        assert result.raw == "Async task result"
        assert result.agent == "Test Agent"
        mock_execute.assert_called_once()

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_aexecute_sync_with_context(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async task execution with context."""
        mock_execute.return_value = "Async result"
        task = Task(
            description="Test task description",
            expected_output="Test expected output",
            agent=test_agent,
        )

        context = "Additional context for the task"
        result = await task.aexecute_sync(context=context)

        assert result is not None
        assert task.prompt_context == context
        mock_execute.assert_called_once()
        call_kwargs = mock_execute.call_args[1]
        assert call_kwargs["context"] == context

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_aexecute_sync_with_tools(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async task execution with custom tools."""
        mock_execute.return_value = "Async result"
        task = Task(
            description="Test task description",
            expected_output="Test expected output",
            agent=test_agent,
        )

        mock_tool = MagicMock()
        mock_tool.name = "test_tool"

        result = await task.aexecute_sync(tools=[mock_tool])

        assert result is not None
        mock_execute.assert_called_once()
        call_kwargs = mock_execute.call_args[1]
        assert mock_tool in call_kwargs["tools"]

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_aexecute_sync_sets_start_and_end_time(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test that async execution sets start and end times."""
        mock_execute.return_value = "Async result"
        task = Task(
            description="Test task description",
            expected_output="Test expected output",
            agent=test_agent,
        )

        assert task.start_time is None
        assert task.end_time is None

        await task.aexecute_sync()

        assert task.start_time is not None
        assert task.end_time is not None
        assert task.end_time >= task.start_time

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_aexecute_sync_stores_output(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test that async execution stores the output."""
        mock_execute.return_value = "Async task result"
        task = Task(
            description="Test task description",
            expected_output="Test expected output",
            agent=test_agent,
        )

        assert task.output is None

        await task.aexecute_sync()

        assert task.output is not None
        assert task.output.raw == "Async task result"

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_aexecute_sync_adds_agent_to_processed_by(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test that async execution adds agent to processed_by_agents."""
        mock_execute.return_value = "Async result"
        task = Task(
            description="Test task description",
            expected_output="Test expected output",
            agent=test_agent,
        )

        assert len(task.processed_by_agents) == 0

        await task.aexecute_sync()

        assert "Test Agent" in task.processed_by_agents

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_aexecute_sync_calls_callback(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test that async execution calls the callback."""
        mock_execute.return_value = "Async result"
        callback = MagicMock()
        task = Task(
            description="Test task description",
            expected_output="Test expected output",
            agent=test_agent,
            callback=callback,
        )

        await task.aexecute_sync()

        callback.assert_called_once()
        assert isinstance(callback.call_args[0][0], TaskOutput)

    @pytest.mark.asyncio
    async def test_aexecute_sync_without_agent_raises(self) -> None:
        """Test that async execution without agent raises exception."""
        task = Task(
            description="Test task",
            expected_output="Test output",
        )

        with pytest.raises(Exception) as exc_info:
            await task.aexecute_sync()

        assert "has no agent assigned" in str(exc_info.value)

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_aexecute_sync_with_different_agent(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async execution with a different agent than assigned."""
        mock_execute.return_value = "Other agent result"
        task = Task(
            description="Test task description",
            expected_output="Test expected output",
            agent=test_agent,
        )

        other_agent = Agent(
            role="Other Agent",
            goal="Other goal",
            backstory="Other backstory",
            llm="gpt-4o-mini",
            verbose=False,
        )

        result = await task.aexecute_sync(agent=other_agent)

        assert result.raw == "Other agent result"
        assert result.agent == "Other Agent"
        mock_execute.assert_called_once()

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_aexecute_sync_handles_exception(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test that async execution handles exceptions properly."""
        mock_execute.side_effect = RuntimeError("Test error")
        task = Task(
            description="Test task description",
            expected_output="Test expected output",
            agent=test_agent,
        )

        with pytest.raises(RuntimeError) as exc_info:
            await task.aexecute_sync()

        assert "Test error" in str(exc_info.value)
        assert task.end_time is not None

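As a usage note, a minimal sketch of driving aexecute_sync from an existing event loop, which is the point of the native-async path tested above; the role/goal strings and model name are placeholders:

import asyncio

from crewai.agent import Agent
from crewai.task import Task

async def main():
    agent = Agent(role="Writer", goal="Write summaries", backstory="...", llm="gpt-4o-mini")
    task = Task(description="Draft a haiku about tests", expected_output="Three lines", agent=agent)
    # aexecute_sync awaits Agent.aexecute_task (the coroutine patched in the
    # tests above) and wraps the result in a TaskOutput.
    output = await task.aexecute_sync()
    print(output.raw)

asyncio.run(main())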
class TestAsyncGuardrails:
    """Tests for async guardrail invocation."""

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_ainvoke_guardrail_success(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async guardrail invocation with successful validation."""
        mock_execute.return_value = "Async task result"

        def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
            return True, output.raw

        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=test_agent,
            guardrail=guardrail_fn,
        )

        result = await task.aexecute_sync()

        assert result is not None
        assert result.raw == "Async task result"

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_ainvoke_guardrail_failure_then_success(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async guardrail that fails then succeeds on retry."""
        mock_execute.side_effect = ["First result", "Second result"]
        call_count = 0

        def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return False, "First attempt failed"
            return True, output.raw

        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=test_agent,
            guardrail=guardrail_fn,
        )

        result = await task.aexecute_sync()

        assert result is not None
        assert call_count == 2

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_ainvoke_guardrail_max_retries_exceeded(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async guardrail that exceeds max retries."""
        mock_execute.return_value = "Async result"

        def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
            return False, "Always fails"

        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=test_agent,
            guardrail=guardrail_fn,
            guardrail_max_retries=2,
        )

        with pytest.raises(Exception) as exc_info:
            await task.aexecute_sync()

        assert "validation after" in str(exc_info.value)
        assert "2 retries" in str(exc_info.value)

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_ainvoke_multiple_guardrails(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async execution with multiple guardrails."""
        mock_execute.return_value = "Async result"
        guardrail1_called = False
        guardrail2_called = False

        def guardrail1(output: TaskOutput) -> tuple[bool, str]:
            nonlocal guardrail1_called
            guardrail1_called = True
            return True, output.raw

        def guardrail2(output: TaskOutput) -> tuple[bool, str]:
            nonlocal guardrail2_called
            guardrail2_called = True
            return True, output.raw

        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=test_agent,
            guardrails=[guardrail1, guardrail2],
        )

        await task.aexecute_sync()

        assert guardrail1_called
        assert guardrail2_called

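The guardrail contract these tests pin down is a plain callable returning (passed, feedback_or_result); below is a sketch under that contract, where the validation rule itself is an invented example rather than anything from this diff:

from crewai.task import Task
from crewai.tasks.task_output import TaskOutput

def no_empty_output(output: TaskOutput) -> tuple[bool, str]:
    # On failure the second element becomes retry feedback; on success it
    # becomes the (possibly transformed) result.
    if not output.raw.strip():
        return False, "Output was empty; try again."
    return True, output.raw

task = Task(
    description="Test task",
    expected_output="Test output",
    guardrail=no_empty_output,
    guardrail_max_retries=2,  # past this many failed retries, execution raises
)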
class TestAsyncTaskOutput:
    """Tests for async task output handling."""

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_aexecute_sync_output_format_raw(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test async execution with raw output format."""
        mock_execute.return_value = '{"key": "value"}'
        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=test_agent,
        )

        result = await task.aexecute_sync()

        assert result.output_format == OutputFormat.RAW

    @pytest.mark.asyncio
    @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
    async def test_aexecute_sync_task_output_attributes(
        self, mock_execute: AsyncMock, test_agent: Agent
    ) -> None:
        """Test that task output has correct attributes."""
        mock_execute.return_value = "Test result"
        task = Task(
            description="Test description",
            expected_output="Test expected",
            agent=test_agent,
            name="Test Task Name",
        )

        result = await task.aexecute_sync()

        assert result.name == "Test Task Name"
        assert result.description == "Test description"
        assert result.expected_output == "Test expected"
        assert result.raw == "Test result"
        assert result.agent == "Test Agent"
@@ -1492,3 +1492,144 @@ def test_flow_copy_state_with_dict_state():

    flow.state["test"] = "modified"
    assert copied_state["test"] == "value"


class TestFlowAkickoff:
    """Tests for the native async akickoff method."""

    @pytest.mark.asyncio
    async def test_akickoff_basic(self):
        """Test basic akickoff execution."""
        execution_order = []

        class SimpleFlow(Flow):
            @start()
            def step_1(self):
                execution_order.append("step_1")
                return "step_1_result"

            @listen(step_1)
            def step_2(self, result):
                execution_order.append("step_2")
                return "final_result"

        flow = SimpleFlow()
        result = await flow.akickoff()

        assert execution_order == ["step_1", "step_2"]
        assert result == "final_result"

    @pytest.mark.asyncio
    async def test_akickoff_with_inputs(self):
        """Test akickoff with inputs."""

        class InputFlow(Flow):
            @start()
            def process_input(self):
                return self.state.get("value", "default")

        flow = InputFlow()
        result = await flow.akickoff(inputs={"value": "custom_value"})

        assert result == "custom_value"

    @pytest.mark.asyncio
    async def test_akickoff_with_async_methods(self):
        """Test akickoff with async flow methods."""
        execution_order = []

        class AsyncMethodFlow(Flow):
            @start()
            async def async_step_1(self):
                execution_order.append("async_step_1")
                await asyncio.sleep(0.01)
                return "async_result"

            @listen(async_step_1)
            async def async_step_2(self, result):
                execution_order.append("async_step_2")
                await asyncio.sleep(0.01)
                return f"final_{result}"

        flow = AsyncMethodFlow()
        result = await flow.akickoff()

        assert execution_order == ["async_step_1", "async_step_2"]
        assert result == "final_async_result"

    @pytest.mark.asyncio
    async def test_akickoff_equivalent_to_kickoff_async(self):
        """Test that akickoff produces the same results as kickoff_async."""
        execution_order_akickoff = []
        execution_order_kickoff_async = []

        class TestFlow(Flow):
            def __init__(self, execution_list):
                super().__init__()
                self._execution_list = execution_list

            @start()
            def step_1(self):
                self._execution_list.append("step_1")
                return "result_1"

            @listen(step_1)
            def step_2(self, result):
                self._execution_list.append("step_2")
                return "result_2"

        flow1 = TestFlow(execution_order_akickoff)
        result1 = await flow1.akickoff()

        flow2 = TestFlow(execution_order_kickoff_async)
        result2 = await flow2.kickoff_async()

        assert execution_order_akickoff == execution_order_kickoff_async
        assert result1 == result2

    @pytest.mark.asyncio
    async def test_akickoff_with_multiple_starts(self):
        """Test akickoff with multiple start methods."""
        execution_order = []

        class MultiStartFlow(Flow):
            @start()
            def start_a(self):
                execution_order.append("start_a")

            @start()
            def start_b(self):
                execution_order.append("start_b")

        flow = MultiStartFlow()
        await flow.akickoff()

        assert "start_a" in execution_order
        assert "start_b" in execution_order

    @pytest.mark.asyncio
    async def test_akickoff_with_router(self):
        """Test akickoff with router method."""
        execution_order = []

        class RouterFlow(Flow):
            @start()
            def begin(self):
                execution_order.append("begin")
                return "data"

            @router(begin)
            def route(self, data):
                execution_order.append("route")
                return "PATH_A"

            @listen("PATH_A")
            def handle_path_a(self):
                execution_order.append("path_a")
                return "path_a_result"

        flow = RouterFlow()
        result = await flow.akickoff()

        assert execution_order == ["begin", "route", "path_a"]
        assert result == "path_a_result"

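A short sketch of why a native akickoff matters: unlike the blocking kickoff(), it can be awaited from an already-running event loop. The flow below is illustrative only, using the same decorators exercised above:

import asyncio

from crewai.flow.flow import Flow, listen, start

class GreetFlow(Flow):
    @start()
    async def fetch_name(self):
        await asyncio.sleep(0)  # stand-in for real async I/O
        return "world"

    @listen(fetch_name)
    def greet(self, name):
        return f"hello {name}"

async def main():
    # Awaitable end-to-end: no thread hop, no nested event loop.
    result = await GreetFlow().akickoff()
    assert result == "hello world"

asyncio.run(main())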
uv.lock (generated; 52 lines changed)
@@ -247,6 +247,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
]

[[package]]
name = "aiosqlite"
version = "0.21.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" },
]
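aiosqlite becomes a hard dependency here, presumably backing the async SQLite storage paths the new tests cover; for reference, a minimal sketch of its API (the table and keys are illustrative):

import asyncio

import aiosqlite

async def main():
    async with aiosqlite.connect(":memory:") as db:
        await db.execute("CREATE TABLE kv (k TEXT PRIMARY KEY, v TEXT)")
        await db.execute("INSERT INTO kv VALUES (?, ?)", ("hello", "world"))
        await db.commit()
        async with db.execute("SELECT v FROM kv WHERE k = ?", ("hello",)) as cursor:
            row = await cursor.fetchone()
            print(row[0])  # -> world

asyncio.run(main())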

[[package]]
name = "annotated-doc"
version = "0.0.4"
@@ -1100,6 +1112,7 @@ wheels = [
name = "crewai"
source = { editable = "lib/crewai" }
dependencies = [
    { name = "aiosqlite" },
    { name = "appdirs" },
    { name = "chromadb" },
    { name = "click" },
@@ -1183,6 +1196,7 @@ watson = [
requires-dist = [
    { name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" },
    { name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
    { name = "aiosqlite", specifier = "~=0.21.0" },
    { name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.71.0" },
    { name = "appdirs", specifier = "~=1.4.4" },
    { name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
@@ -1843,7 +1857,7 @@ name = "exceptiongroup"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "typing-extensions", marker = "python_full_version < '3.13'" },
    { name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
wheels = [
@@ -4337,7 +4351,7 @@ name = "nvidia-cudnn-cu12"
version = "9.10.2.21"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "nvidia-cublas-cu12" },
    { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
@@ -4348,7 +4362,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "nvidia-nvjitlink-cu12" },
    { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
@@ -4375,9 +4389,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "nvidia-cublas-cu12" },
    { name = "nvidia-cusparse-cu12" },
    { name = "nvidia-nvjitlink-cu12" },
    { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
@@ -4388,7 +4402,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "nvidia-nvjitlink-cu12" },
    { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
@@ -4448,9 +4462,9 @@ name = "ocrmac"
version = "1.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "click" },
    { name = "pillow" },
    { name = "pyobjc-framework-vision" },
    { name = "click", marker = "sys_platform == 'darwin'" },
    { name = "pillow", marker = "sys_platform == 'darwin'" },
    { name = "pyobjc-framework-vision", marker = "sys_platform == 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/dd/dc/de3e9635774b97d9766f6815bbb3f5ec9bce347115f10d9abbf2733a9316/ocrmac-1.0.0.tar.gz", hash = "sha256:5b299e9030c973d1f60f82db000d6c2e5ff271601878c7db0885e850597d1d2e", size = 1463997, upload-time = "2024-11-07T12:00:00.197Z" }
wheels = [
@@ -6050,7 +6064,7 @@ name = "pyobjc-framework-cocoa"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pyobjc-core" },
    { name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/02/a3/16ca9a15e77c061a9250afbae2eae26f2e1579eb8ca9462ae2d2c71e1169/pyobjc_framework_cocoa-12.1.tar.gz", hash = "sha256:5556c87db95711b985d5efdaaf01c917ddd41d148b1e52a0c66b1a2e2c5c1640", size = 2772191, upload-time = "2025-11-14T10:13:02.069Z" }
wheels = [
@@ -6066,8 +6080,8 @@ name = "pyobjc-framework-coreml"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pyobjc-core" },
    { name = "pyobjc-framework-cocoa" },
    { name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
    { name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/30/2d/baa9ea02cbb1c200683cb7273b69b4bee5070e86f2060b77e6a27c2a9d7e/pyobjc_framework_coreml-12.1.tar.gz", hash = "sha256:0d1a4216891a18775c9e0170d908714c18e4f53f9dc79fb0f5263b2aa81609ba", size = 40465, upload-time = "2025-11-14T10:14:02.265Z" }
wheels = [
@@ -6083,8 +6097,8 @@ name = "pyobjc-framework-quartz"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pyobjc-core" },
    { name = "pyobjc-framework-cocoa" },
    { name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
    { name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/94/18/cc59f3d4355c9456fc945eae7fe8797003c4da99212dd531ad1b0de8a0c6/pyobjc_framework_quartz-12.1.tar.gz", hash = "sha256:27f782f3513ac88ec9b6c82d9767eef95a5cf4175ce88a1e5a65875fee799608", size = 3159099, upload-time = "2025-11-14T10:21:24.31Z" }
wheels = [
@@ -6100,10 +6114,10 @@ name = "pyobjc-framework-vision"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pyobjc-core" },
    { name = "pyobjc-framework-cocoa" },
    { name = "pyobjc-framework-coreml" },
    { name = "pyobjc-framework-quartz" },
    { name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
    { name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" },
    { name = "pyobjc-framework-coreml", marker = "sys_platform == 'darwin'" },
    { name = "pyobjc-framework-quartz", marker = "sys_platform == 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c2/5a/08bb3e278f870443d226c141af14205ff41c0274da1e053b72b11dfc9fb2/pyobjc_framework_vision-12.1.tar.gz", hash = "sha256:a30959100e85dcede3a786c544e621ad6eb65ff6abf85721f805822b8c5fe9b0", size = 59538, upload-time = "2025-11-14T10:23:21.979Z" }
wheels = [